加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
searchServer.py 8.70 KB
一键复制 编辑 原始数据 按行查看 历史
王有平 提交于 2023-12-25 15:06 . '初始化'
# -*- coding: utf-8 -*-
import hashlib
import io
import logging
import os
import threading
import time
import urllib
import urllib.parse
from glob import glob

from diskcache import Cache
from PIL import Image
from elasticsearch import Elasticsearch
from flask import Flask, request, render_template
from towhee.dc2 import pipe, ops

# -*- coding: utf-8 -*-
import config
import extractFeatures
import image_decode_custom
# Image-to-image ("search by image") service: module-level setup shared by
# the Flask endpoints and the towhee pipelines defined below.
'''
以图搜图服务
'''
# Flask application. JSON_AS_ASCII=False is meant to keep non-ASCII (Chinese)
# text unescaped in JSON responses.
# NOTE(review): newer Flask releases moved this setting to
# app.json.ensure_ascii - confirm the installed version still honours it.
app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False
# Unused alternatives: host/port from environment variables.
# es_host = os.environ.get("ES_HOST", config.elasticsearch_url)
# es_port = os.environ.get("ES_PORT", config.elasticsearch_port)
# Index name: ES_INDEX environment variable, falling back to config.
es_index = os.environ.get("ES_INDEX", config.elasticsearch_index)
# Elasticsearch client built from config (host, port, basic-auth credentials)
# with a 3600 s request timeout.
es = Elasticsearch([{'host': config.elasticsearch_url, 'port': config.elasticsearch_port,
                     'http_auth': (config.es_name, config.es_password)}], timeout=3600)
# NOTE: this rebinds the name of the imported image_decode_custom module to a
# decoder instance; from here on the name refers to the instance, which the
# embedding pipeline uses as its decode step.
image_decode_custom = image_decode_custom.ImageDecodeCV2()
# Path of the most recently saved query image; the search endpoints remove the
# file it points to on a later upload. Empty string means "none yet".
last_upload_img = ""
# Logging: logger named "log" at DEBUG level. No handler is attached in this
# file, so where output goes depends on logging configuration elsewhere.
logger = logging.getLogger("log")
logger.setLevel(logging.DEBUG)
# Save path (original comment: "extraction path"); search() writes uploaded
# query images into this directory.
save_path = config.save_path
# Port the Flask server listens on.
server_port = config.server_port
# Relative score threshold read by feature_search (a hit is kept when its
# score >= global_similarity * best score); overwritten per request by the
# search endpoints.
# NOTE(review): module-level state shared by all request threads (the app
# runs with threaded=True), so concurrent searches can race on it.
global_similarity = 1
# Disk-backed result cache stored in ./result_cache.
cache = Cache('./result_cache')
# Elasticsearch similarity query
def feature_search(query, similarity=None):
    """Return indexed images whose stored feature is similar to ``query``.

    Sends a ``script_score`` query that scores every document in ``es_index``
    by ``cosineSimilarity`` between the query vector and the stored
    ``feature`` field, keeping at most 30 hits.

    Args:
        query: embedding emitted by the pipeline. Only every 8th component
            (``query[::8]``) is sent - presumably to match the dimension the
            index was built with; confirm against extractFeatures.
        similarity: relative threshold; a hit is kept when its score is at
            least ``similarity * max_score``. ``None`` (the default, and what
            the pipeline passes) falls back to the module-level
            ``global_similarity``.

    Returns:
        List of ``[img_url, name]`` pairs, with ``#`` in the URL escaped as
        ``%23``. Empty when nothing matched or the best score is below 0.35.
    """
    if similarity is None:
        similarity = global_similarity
    body = {
        "size": 30,
        "query": {
            "bool": {
                "must": [
                    {
                        "script_score": {
                            "query": {
                                "match_all": {}
                            },
                            "script": {
                                "source": "cosineSimilarity(params.queryVector, doc['feature'])",
                                "params": {
                                    "queryVector": query[::8]
                                }
                            }
                        }
                    }
                ]
            }
        }
    }
    # The body embeds the whole query vector: log it lazily at DEBUG instead
    # of printing it to stdout on every request.
    logger.debug("es query body: %s", body)
    results = es.search(
        index=es_index,
        body=body
    )
    hits = results['hits']
    if hits['total']['value'] <= 0:
        return []
    max_score = hits['max_score']
    # Absolute floor: when even the best match is this weak, report nothing.
    if max_score is None or max_score < 0.35:
        return []
    threshold = similarity * max_score
    logger.info("max_score%s; global_similarity * max_score: %s", max_score, threshold)
    return [
        [hit['_source']['url'].replace("#", "%23"), hit['_source']['name']]
        for hit in hits['hits']
        if hit['_score'] >= threshold
    ]
def search_es(img_path):
    """Run the search pipeline on the image stored at ``img_path``.

    Returns the pipeline's output object unchanged; callers read the matches
    with ``.get()[0]``.
    """
    return p_search(img_path)
def add_catch_result(key, value, expire_time):
    """Store ``value`` in the disk cache under ``key`` for ``expire_time`` seconds.

    Uses diskcache's ``Cache.add``, which per its documentation only stores
    when ``key`` is not already present. Its success flag is discarded and
    this function returns ``None``.
    """
    cache.add(key, value, expire=expire_time)
def get_catch_result(key):
    """Look up ``key`` in the disk cache; ``None`` (diskcache's default) if absent."""
    return cache.get(key)
@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the search page with no query image and no results."""
    template_name = 'index.html'
    return render_template(template_name)
# Search images (HTML page endpoint)
@app.route('/search', methods=['GET', 'POST'])
def search():
    """Render the search page, running an image search on POST.

    GET renders ``index.html`` without results. POST expects a multipart form
    with ``query_img`` (the query image) and ``similarity`` (a number in
    (0, 1]). The upload is saved under ``save_path`` and searched through
    ``p_search`` (results are cached for 60 s). The page is then rendered with
    ``query_path`` (URL-quoted path of the saved upload) and ``scores`` (list
    of ``[img_url, name]`` matches).

    Validation problems and unexpected failures are answered with plain-text
    messages; nothing is raised to Flask.
    """
    global global_similarity, last_upload_img
    if request.method != 'POST':
        return render_template('index.html')
    try:
        file = request.files['query_img']
        if file.filename == "":
            return "请选择要上传的图片"
        similarity = request.form.get("similarity")
        # A missing or non-numeric value used to crash into the generic
        # "service error" reply; report it as an invalid range instead.
        try:
            similarity_value = float(similarity)
        except (TypeError, ValueError):
            return "相似度范围为0~1"
        # Positive range test so that NaN is rejected too.
        if not 0 < similarity_value <= 1:
            return "相似度范围为0~1"
        # Read the upload once: the bytes feed both PIL and the cache key.
        data = file.stream.read()
        img = Image.open(io.BytesIO(data))  # PIL image
        # Keep only the final component of the client-supplied filename so the
        # file cannot be written outside save_path ("../" traversal).
        uploaded_img_path = save_path + "/" + os.path.basename(file.filename)
        # Save before consulting the cache: the page is rendered with
        # query_path pointing at this file (presumably displayed by
        # index.html), so it must exist on a cache hit as well.
        img.save(uploaded_img_path)
        # Key on the image content, not just its name, so different images
        # uploaded under the same filename do not share cached results.
        catch_key = hashlib.sha256(data).hexdigest() + str(similarity)
        answers = get_catch_result(catch_key)
        if answers:
            logger.info("有缓存-----")
        else:
            # feature_search (inside p_search) reads this module-level value.
            # NOTE(review): shared across request threads; concurrent searches
            # can observe each other's threshold.
            global_similarity = similarity_value
            started = time.time()
            dc = p_search(uploaded_img_path)
            logger.info('search file time:' + str(time.time() - started))
            answers = dc.get()[0]
            add_catch_result(catch_key, answers, 60)
        # Remove the previous query image - but never the file just saved
        # (same filename uploaded twice in a row), which is rendered below.
        if last_upload_img and last_upload_img != uploaded_img_path:
            if os.path.exists(last_upload_img):
                os.remove(last_upload_img)
            else:
                logger.info('删除上一次上传图片失败: %s', last_upload_img)
        last_upload_img = uploaded_img_path
        return render_template('index.html',
                               query_path=urllib.parse.quote(uploaded_img_path),
                               scores=answers)
    except Exception as e:
        logger.exception("/search接口异常信息:" + str(e))
        return "服务异常,请检查上传图片格式是否正确或者图片是否损坏!"
@app.route('/searchImg', methods=['POST'])
def searchImg():
    """JSON image-search API.

    Expects a multipart POST with ``query_img`` (image file) and
    ``similarity`` (number in (0, 1]). Returns a ``(dict, status)`` tuple
    shaped ``{'status': bool, 'msg': str, 'data': list | None}``. On success,
    ``data`` is a list of ``{"img_path": ..., "identity": ...}`` dicts, where
    ``identity`` is the matched image name up to its first ``-``. Successful
    results are cached for 15552000 s (180 days). Every failure, including
    unexpected exceptions, is reported with HTTP 400.
    """
    global global_similarity, last_upload_img
    try:
        started = time.time()
        file = request.files['query_img']
        if file.filename == "":
            return {'status': False, 'msg': "query_img参数为空", 'data': None}, 400
        similarity = request.form.get("similarity")
        if similarity is None:
            return {'status': False, 'msg': "缺少similarity参数", 'data': None}, 400
        try:
            similarity_value = float(similarity)
        except ValueError:
            similarity_value = None
        # Same JSON error shape as the other validation failures (this used to
        # be a bare string with HTTP 200). The positive range test rejects NaN.
        if similarity_value is None or not 0 < similarity_value <= 1:
            return {'status': False, 'msg': "相似度范围为0~1", 'data': None}, 400
        # Read the upload once: the bytes feed both PIL and the cache key.
        data = file.stream.read()
        img = Image.open(io.BytesIO(data))  # PIL image
        # Key on the image content, not just its name: entries live for half
        # a year, so a filename-only key would serve stale results for any
        # later, different image uploaded under the same name.
        catch_key = hashlib.sha256(data).hexdigest() + str(similarity)
        cached = get_catch_result(catch_key)
        if cached:
            return {'status': True, 'msg': '查询成功!', 'data': cached}, 200
        # basename() blocks "../" traversal via the client-supplied filename.
        # NOTE(review): this directory is hard-coded while /search uses
        # save_path - confirm they are meant to be the same.
        uploaded_img_path = "static/uploaded/" + os.path.basename(file.filename)
        img.save(uploaded_img_path)
        logger.info('get file time:' + str(time.time() - started))
        # feature_search (inside the pipeline) reads this module-level value.
        # NOTE(review): shared across request threads.
        global_similarity = similarity_value
        dc = search_es(uploaded_img_path)
        logger.info('search file time:' + str(time.time() - started))
        answers = dc.get()[0]
        # Remove the previous upload (unless it is the file just saved) and
        # remember this one, so uploaded images do not pile up on disk.
        if last_upload_img and last_upload_img != uploaded_img_path:
            if os.path.exists(last_upload_img):
                os.remove(last_upload_img)
            else:
                logger.info('删除上一次上传图片失败: %s', last_upload_img)
        last_upload_img = uploaded_img_path
        result = [
            {"img_path": img_url, "identity": name.split("-")[0]}
            for img_url, name in answers
        ]
        # Cache the result for half a year.
        add_catch_result(key=catch_key, value=result, expire_time=15552000)
        return {'status': True, 'msg': '查询成功!', 'data': result}, 200
    except Exception as e:
        logger.exception(f"Get url error: {e}")
        return {'status': False, 'msg': str(e.args), 'data': None}, 400
@app.get("/extract")
def extract():
    """Start feature extraction in a background thread and return at once.

    The work is done by ``extractFeatures.start_extract`` in the new thread;
    the response is only a plain-text acknowledgement.
    NOTE(review): nothing prevents several extractions from running at the
    same time if this endpoint is hit repeatedly.
    """
    # Pass the function object as ``target``. Calling start_extract() here
    # would run the whole extraction synchronously inside the request and
    # hand its return value to Thread as the ``group`` argument.
    worker = threading.Thread(target=extractFeatures.start_extract)
    worker.start()
    return "数据正在上载!"
# Load image path
def load_image(folderPath, extensions=None):
    """Yield image file paths selected by ``folderPath``.

    ``folderPath`` is either a glob pattern or a literal file path (the search
    endpoints pass the path of the saved upload). A path naming an existing
    file is used as-is: uploaded filenames may contain glob metacharacters
    such as ``[``, ``]`` or ``*``, which would otherwise keep the file from
    matching itself.

    Args:
        folderPath: glob pattern or path of a single file.
        extensions: accepted file extensions (with leading dot); defaults to
            ``config.types``. The lower-cased extension is also accepted, so
            ``photo.JPG`` passes when the list holds ``.jpg``.

    Yields:
        Matching paths whose extension is accepted.
    """
    allowed = config.types if extensions is None else extensions
    candidates = [folderPath] if os.path.isfile(folderPath) else glob(folderPath)
    for path in candidates:
        ext = os.path.splitext(path)[1]
        if ext in allowed or ext.lower() in allowed:
            yield path
# Embedding pipeline: src (path or glob pattern) -> image paths -> decoded
# images -> embedding vectors.
p_embed = (
    pipe.input('src')
    # src -> img_path: load_image expands src into zero or more image paths
    .flat_map('src', 'img_path', load_image)
    # img_path -> img: decoded by the module-level ImageDecodeCV2 instance
    .map('img_path', 'img', image_decode_custom)
    # img -> vec: ResNet-50 embedding via towhee's timm image_embedding operator
    .map('img', 'vec', ops.image_embedding.timm(model_name='resnet50'))
)
# Search pipeline: each embedding is passed to feature_search (Elasticsearch
# lookup), producing search_res.
p_search_pre = (
    p_embed.map('vec', 'search_res', feature_search)
)
# Expose only search_res as output; callers read it with dc.get()[0].
p_search = p_search_pre.output('search_res')
if __name__ == "__main__":
    # The Werkzeug interactive debugger executes arbitrary Python sent from a
    # browser, so it must not be on by default for a server bound to 0.0.0.0.
    # Opt in for local development with FLASK_DEBUG=1 (or true/yes).
    debug_mode = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run("0.0.0.0", port=server_port, debug=debug_mode, threaded=True)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化