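# NOTE: a line of web-page residue ("code pull complete, page will refresh automatically") stood here; it is not code.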
# -*- coding: utf-8 -*-
import logging
import os
import threading
import time
import urllib
from glob import glob
from diskcache import Cache
from PIL import Image
from elasticsearch import Elasticsearch
from flask import Flask, request, render_template
from towhee.dc2 import pipe, ops
# -*- coding: utf-8 -*-
import config
import extractFeatures
import image_decode_custom
'''
以图搜图服务
'''
# (The string above reads "Image-to-image search service".)
# Module-level setup: Flask app, Elasticsearch client, image decoder instance,
# shared mutable state used by the request handlers, and the result cache.
app = Flask(__name__)
# Keep non-ASCII (e.g. Chinese) text unescaped in JSON responses.
app.config['JSON_AS_ASCII'] = False
# es_host = os.environ.get("ES_HOST", config.elasticsearch_url)
# es_port = os.environ.get("ES_PORT", config.elasticsearch_port)
# Index name: the ES_INDEX environment variable overrides config.elasticsearch_index.
es_index = os.environ.get("ES_INDEX", config.elasticsearch_index)
# Elasticsearch client. Host, port and credentials come straight from config
# (the ES_HOST / ES_PORT overrides above are commented out); timeout=3600 is
# passed to the client constructor.
es = Elasticsearch([{'host': config.elasticsearch_url, 'port': config.elasticsearch_port,
'http_auth': (config.es_name, config.es_password)}], timeout=3600)
# NOTE(review): this rebinds the module name `image_decode_custom` to an
# ImageDecodeCV2 instance; the embedding pipeline further down uses the instance.
image_decode_custom = image_decode_custom.ImageDecodeCV2()
# Path of the last image saved by /search; the next request tries to delete it.
last_upload_img = ""
# Logger
logger = logging.getLogger("log")
logger.setLevel(logging.DEBUG)
# NOTE(review): no handler is attached here; unless logging is configured
# elsewhere, INFO/DEBUG records from this logger may never be emitted - verify.
# Directory where /search stores uploaded query images
save_path = config.save_path
# Port the server listens on
server_port = config.server_port
# Relative similarity cut-off read by feature_search(); each search request
# overwrites it. NOTE(review): shared by all request threads, so concurrent
# searches with different thresholds can interfere with each other.
global_similarity = 1
# Disk-backed result cache stored under ./result_cache
cache = Cache('./result_cache')
# Elasticsearch query
def feature_search(query, similarity=None, size=30, min_score=0.35):
    """Find indexed images whose feature vector is close to *query*.

    Runs a ``script_score`` query over every document in ``es_index``. Each
    document is scored by ``cosineSimilarity`` between its ``feature`` field
    and the query vector. Only every 8th component of *query* is sent
    (``query[::8]``), so the indexed vectors are presumably down-sampled the
    same way - TODO confirm against the indexing code.

    Args:
        query: embedding produced by the pipeline; must support ``[::8]`` slicing.
        similarity: relative cut-off. Hits scoring below
            ``similarity * max_score`` are dropped. Defaults to the
            module-level ``global_similarity`` set by the request handlers.
        size: maximum number of hits requested from Elasticsearch.
        min_score: when the best hit scores below this, nothing is returned.

    Returns:
        A list of ``[img_url, name]`` pairs, with ``#`` in URLs escaped as
        ``%23``. The list is empty when nothing clears the thresholds.
    """
    if similarity is None:
        similarity = global_similarity
    body = {
        "size": size,
        "query": {
            "bool": {
                "must": [
                    {
                        "script_score": {
                            "query": {"match_all": {}},
                            "script": {
                                "source": "cosineSimilarity(params.queryVector, doc['feature'])",
                                "params": {"queryVector": query[::8]},
                            },
                        }
                    }
                ]
            }
        },
    }
    # The body embeds the whole query vector: keep it out of stdout and only
    # emit it at DEBUG level.
    logger.debug("feature_search body: %s", body)
    results = es.search(index=es_index, body=body)
    hits = results['hits']
    if hits['total']['value'] <= 0:
        return []
    max_score = hits['max_score']
    if max_score is None or max_score < min_score:
        return []
    cutoff = similarity * max_score
    logger.info("max_score%s; global_similarity * max_score: %s", max_score, cutoff)
    answers = []
    for hit in hits['hits']:
        if hit['_score'] >= cutoff:
            source = hit['_source']
            # '#' would otherwise be parsed as a URL fragment by the browser.
            answers.append([source['url'].replace("#", "%23"), source['name']])
    return answers
def search_es(img_path):
    """Feed the image stored at *img_path* through the search pipeline.

    Returns the object produced by ``p_search`` unchanged; callers read the
    hit list from it with ``.get()[0]``.
    """
    return p_search(img_path)
def add_catch_result(key, value, expire_time):
    """Put *value* into the disk cache under *key*, expiring after *expire_time*.

    Delegates to ``Cache.add``, which (per the diskcache docs) only stores the
    value when *key* is not already present. Its boolean result is discarded.
    """
    cache.add(key, value, expire=expire_time)
def get_catch_result(key):
    """Look up *key* in the disk cache; None (diskcache's default) when absent."""
    return cache.get(key)
@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the search page (``index.html``) with no query image or results."""
    page = 'index.html'
    return render_template(page)
# Search by uploaded image (HTML form endpoint)
@app.route('/search', methods=['GET', 'POST'])
def search():
    """Handle the HTML search form.

    GET renders the empty page. POST expects two form fields:

    * ``query_img``: the query image.
    * ``similarity``: relative cut-off in (0, 1].

    The image is saved under ``save_path`` and run through ``p_search``, with
    results cached for 60 seconds. The page is then rendered with
    ``query_path`` (URL-quoted path of the saved image) and ``scores`` (a list
    of ``[img_url, name]`` pairs).

    Validation and processing failures are returned as plain-text messages.
    """
    import hashlib
    import io
    import urllib.parse

    global global_similarity, last_upload_img
    if request.method != 'POST':
        return render_template('index.html')
    try:
        file = request.files['query_img']
        if file.filename == "":
            return "请选择要上传的图片"
        # Missing, empty, non-numeric and NaN values all fail this validation.
        try:
            similarity = float(request.form.get("similarity"))
        except (TypeError, ValueError):
            return "相似度范围为0~1"
        if not 0 < similarity <= 1:
            return "相似度范围为0~1"
        data = file.read()
        # Opening first rejects non-image uploads before anything touches disk.
        img = Image.open(io.BytesIO(data))
        # The filename is client-controlled: keep only its last path component.
        filename = os.path.basename(file.filename)
        if not filename:
            return "请选择要上传的图片"
        uploaded_img_path = save_path + "/" + filename
        # Key on the image content (different images often share a filename).
        # Namespace the key so /searchImg's differently shaped entries never match.
        catch_key = "search:" + hashlib.sha256(data).hexdigest() + str(similarity)
        # Save on every request so the preview shown on the page exists even
        # when the results come from the cache.
        img.save(uploaded_img_path)
        answers = get_catch_result(catch_key)
        if answers:
            logger.info("有缓存-----")
        else:
            global_similarity = similarity
            started = time.time()
            row = p_search(uploaded_img_path).get()
            logger.info('search file time:%s', time.time() - started)
            # No row comes back when load_image() filters the path out.
            if row is None:
                logger.warning("no pipeline output for %s", uploaded_img_path)
                answers = []
            else:
                answers = row[0]
            add_catch_result(catch_key, answers, 60)
        # Remove the previous upload, unless it is the file just written.
        previous = last_upload_img
        if previous and previous != uploaded_img_path:
            if os.path.exists(previous):
                os.remove(previous)
            else:
                logger.info('删除上一次上传图片失败:%s', previous)
        last_upload_img = uploaded_img_path
        return render_template('index.html',
                               query_path=urllib.parse.quote(uploaded_img_path),
                               scores=answers)
    except Exception as e:
        logger.exception("/search接口异常信息:%s", e)
        return "服务异常,请检查上传图片格式是否正确或者图片是否损坏!"
@app.route('/searchImg', methods=['POST'])
def searchImg():
    """JSON API: search the index for images similar to an uploaded one.

    Form fields:

    * ``query_img``: the query image file.
    * ``similarity``: relative cut-off in (0, 1]. A hit is kept when its
      score is at least ``similarity`` times the best score.

    The upload is written to a uniquely named file under ``static/uploaded/``
    only for the duration of the search, then removed. Results are cached for
    half a year, keyed on the image content and the similarity.

    Returns:
        ``(body, status)``. ``body`` is ``{'status', 'msg', 'data'}``. On
        success ``data`` is a list of ``{'img_path', 'identity'}`` dicts,
        where ``identity`` is the part of the indexed name before the first
        ``-``. Errors return status 400 with ``data`` set to None.
    """
    import hashlib
    import io
    import uuid

    global global_similarity
    try:
        started = time.time()
        file = request.files.get('query_img')
        if file is None or file.filename == "":
            return {'status': False, 'msg': "query_img参数为空", 'data': None}, 400
        similarity = request.form.get("similarity")
        if similarity is None:
            return {'status': False, 'msg': "缺少similarity参数", 'data': None}, 400
        try:
            similarity_value = float(similarity)
        except ValueError:
            similarity_value = float("nan")  # rejected by the range check below
        if not 0 < similarity_value <= 1:
            return {'status': False, 'msg': "相似度范围为0~1", 'data': None}, 400
        data = file.read()
        # Opening first rejects non-image uploads with an error response.
        img = Image.open(io.BytesIO(data))
        # Content-based, endpoint-namespaced key: same-named but different
        # images must not share results, and /search stores a different shape.
        cache_key = "searchImg:" + hashlib.sha256(data).hexdigest() + str(similarity_value)
        cached = get_catch_result(cache_key)
        if cached:
            return {'status': True, 'msg': '查询成功!', 'data': cached}, 200
        # The unique name avoids client-controlled paths and clashes between
        # concurrent uploads. The extension is kept for the save format and
        # for load_image's type filter.
        extension = os.path.splitext(os.path.basename(file.filename))[1]
        uploaded_img_path = "static/uploaded/" + uuid.uuid4().hex + extension
        try:
            img.save(uploaded_img_path)
            logger.info('get file time:%s', time.time() - started)
            global_similarity = similarity_value
            row = search_es(uploaded_img_path).get()
        finally:
            # The file is only needed by the pipeline; never leave it behind.
            if os.path.exists(uploaded_img_path):
                os.remove(uploaded_img_path)
        logger.info('search file time:%s', time.time() - started)
        if row is None:
            logger.warning("no pipeline output for %s", uploaded_img_path)
            answers = []
        else:
            answers = row[0]
        result = [
            {"img_path": img_url, "identity": name.split("-")[0]}
            for img_url, name in answers
        ]
        # Cache for half a year
        add_catch_result(key=cache_key, value=result, expire_time=15552000)
        return {'status': True, 'msg': '查询成功!', 'data': result}, 200
    except Exception as e:
        logger.exception("Get url error: %s", e)
        return {'status': False, 'msg': str(e.args), 'data': None}, 400
@app.get("/extract")
def extract():
    """Start feature extraction in a background thread and respond immediately.

    Returns a plain-text acknowledgement. Extraction progress is not reported
    by this endpoint.
    """
    # Pass the function itself as the thread target. Calling it here would run
    # the whole extraction synchronously in the request thread. Its return
    # value would then land in Thread's `group` parameter.
    threading.Thread(target=extractFeatures.start_extract, name="extract-features").start()
    return "数据正在上载!"
# Load image path
def load_image(folderPath):
    """Yield paths matching the glob pattern *folderPath* with an allowed extension.

    A path is kept when its extension (``os.path.splitext(...)[1]``, including
    the dot) is in ``config.types``. This is a generator, so ``glob`` only runs
    on the first iteration.
    """
    for candidate in glob(folderPath):
        _, extension = os.path.splitext(candidate)
        if extension not in config.types:
            continue
        yield candidate
# Embedding pipeline: path (or glob pattern) -> image path(s) -> decoded image -> vector
p_embed = (
pipe.input('src')
# load_image expands src; presumably one row per yielded image path
.flat_map('src', 'img_path', load_image)
# Decode img_path into img with the module-level ImageDecodeCV2 instance
.map('img_path', 'img', image_decode_custom)
# Embed img into vec with a timm ResNet-50 model
.map('img', 'vec', ops.image_embedding.timm(model_name='resnet50'))
)
# Search pipeline: run feature_search() on each vec to produce search_res
p_search_pre = (
p_embed.map('vec', 'search_res', feature_search)
)
# Expose only search_res; handlers read the hit list with p_search(path).get()[0]
p_search = p_search_pre.output('search_res')
if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so it must not be on by default for a server bound to 0.0.0.0. Opt in
    # with FLASK_DEBUG=1 (or true/yes) for local development only.
    debug_mode = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run("0.0.0.0", port=server_port, debug=debug_mode, threaded=True)
# NOTE: the hosting web page withheld the remainder of this file at this point
# ("content here may be unsuitable for display"), followed by an appeal notice;
# any code that came after is not part of this view.