加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
server.py 15.06 KB
一键复制 编辑 原始数据 按行查看 历史
xiaowei305 提交于 2022-10-30 17:21 . add filter
# -*- coding: utf-8 -*-
from flask import Flask
from flask import flash, session, redirect, request, url_for
from flask import render_template
from flask_login import LoginManager,login_user,login_required,current_user
from flask_wtf.file import FileField, FileRequired, FileAllowed
from flask_wtf import FlaskForm
from flask_sockets import Sockets
from flask_cors import CORS
from werkzeug.routing.rules import Rule
from wtforms import SubmitField, StringField, PasswordField, BooleanField
from wtforms.validators import Length,DataRequired,Optional
from geventwebsocket.exceptions import WebSocketError
import os
import argparse
import pickle
import uuid
import logging
import base64
import json
import sqlite3
import threading, queue
import asyncio
import time
import cv2
import numpy as np
from sys import argv
from inference import CenterFace, HSEmotionRecognizer, InsightFace
# Run from this file's directory so the relative paths used below
# ('emotion.db', 'models/...', 'face.dat') resolve regardless of the
# directory the process was started from.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

app = Flask(__name__)
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = "login"

app.config['MAX_CONTENT_LENGTH'] = 1 * 1024 * 1024 * 8  # 8 MiB request cap
app.config['UPLOAD_PATH'] = os.path.join(app.root_path, 'uploads')
# NOTE(review): debug mode is on; it must be switched off for production.
app.debug = True
# NOTE(review): the secret key is hard-coded in source; consider loading it
# from the environment or a private config file before deploying.
app.config["SECRET_KEY"] = "11235813"
app.config['JSON_AS_ASCII'] = False
CORS(app, supports_credentials=True)
sockets = Sockets(app)

# One module-wide SQLite connection shared by every database helper below.
database_conn = sqlite3.connect('emotion.db')

# Input size handed to the face detector.
im_width = 1280
im_height = 768
# Largest absolute pitch/roll/yaw accepted for a "frontal" face
# (presumably degrees -- confirm against CenterFace.get_orientation).
angle_thresh = 30

centerface = CenterFace('models/centerface_int8.engine', im_height, im_width)
emotion_model = HSEmotionRecognizer(model_name='enet_b0_8_va_mtl')
recognition_model = InsightFace("models/recognition/insight_r34.engine")

# Registered faces: uploaded file name -> feature array, persisted in FACE_DATA.
face_dict = {}
FACE_DATA = "face.dat"
if os.path.exists(FACE_DATA):
    # NOTE(review): pickle must only ever be loaded from a trusted file.
    # The context manager closes the handle, which the bare open() did not.
    with open(FACE_DATA, 'rb') as _face_file:
        face_dict = pickle.load(_face_file)
class UploadForm(FlaskForm):
    """Upload form with one required JPEG file field and a submit button."""

    photo = FileField(
        'Upload Image',
        validators=[
            FileRequired(),
            FileAllowed(['jpg', 'jpeg']),
        ],
    )
    submit = SubmitField()
def random_filename(filename):
    """Return a random 32-character hex name keeping *filename*'s extension."""
    _, extension = os.path.splitext(filename)
    return f"{uuid.uuid4().hex}{extension}"
def check_face_num(boxes):
    """Check that exactly one face was detected.

    Args:
        boxes: sized collection of detections (only ``len()`` is used).

    Returns:
        ``(ok, response)``. ``ok`` is True only for exactly one detection;
        otherwise ``response`` is the error payload to send to the client
        (code 202 for no face, 203 for several faces).
    """
    face_count = len(boxes)
    if face_count == 1:
        return True, {"code": 200}
    if face_count == 0:
        return False, {"code": 202, "msg": "无人脸或人脸不完整", "data": {}}
    return False, {"code": 203, "msg": "存在多个人脸", "data": {}}
@app.route('/register', methods=['GET', 'POST'])
def register():
    """Register the single frontal face in an uploaded image.

    GET renders the upload page. POST expects the image in the ``file``
    form field; the features of the detected face are stored in
    ``face_dict`` under the uploaded file name and persisted to FACE_DATA.

    Returns:
        A JSON-able dict: code 200 on success, 202 for an unreadable image
        or no face, 203 for several faces, 204 for a non-frontal face.
    """
    if request.method != 'POST':
        return render_template('upload.html', url=request.url_rule)

    upload = request.files['file']
    fname = upload.filename
    # Decode straight from memory. Saving to (and then deleting) a path named
    # by the client let a crafted filename overwrite and remove arbitrary
    # files in the working directory.
    raw = np.frombuffer(upload.read(), dtype=np.uint8)
    # imdecode rejects an empty buffer, so treat that as an unreadable image.
    frame = cv2.imdecode(raw, cv2.IMREAD_COLOR) if raw.size else None
    if frame is None:
        return {"code":202,"msg":"图像错误,请上传jpg/png格式","data":{}}

    dets, landmarks = centerface(frame)
    ok, msg = check_face_num(dets)
    if not ok:
        return msg

    # Exactly one face is present here, so take the first orientation triple.
    pitch, roll, yaw = centerface.get_orientation(dets, landmarks)[0]
    if any(np.abs(angle) > angle_thresh for angle in (pitch, roll, yaw)):
        return {"code":204,"msg":"请上传正面人脸","data":[]}

    features = recognition_model.get_features(frame, dets, landmarks)
    face_dict[fname] = features.copy()
    with open(FACE_DATA, "wb") as face_file:
        pickle.dump(face_dict, face_file)
    return {"code":200, "msg": "OK"}
@app.route('/feature', methods=['GET', 'POST'])
def get_feature():
    """Return the recognition features of the single frontal face uploaded.

    GET renders the upload page. POST expects the image in the ``file``
    form field.

    Returns:
        A JSON-able dict: code 200 with the feature list in ``data`` on
        success, 202 for an unreadable image or no face, 203 for several
        faces, 204 for a non-frontal face.
    """
    if request.method != 'POST':
        return render_template('upload.html', url=request.url_rule)

    upload = request.files['file']
    # Decode straight from memory. Saving to (and then deleting) a path named
    # by the client let a crafted filename overwrite and remove arbitrary
    # files in the working directory.
    raw = np.frombuffer(upload.read(), dtype=np.uint8)
    # imdecode rejects an empty buffer, so treat that as an unreadable image.
    frame = cv2.imdecode(raw, cv2.IMREAD_COLOR) if raw.size else None
    if frame is None:
        return {"code":202,"msg":"图像错误,请上传jpg/png格式","data":{}}

    dets, landmarks = centerface(frame)
    ok, msg = check_face_num(dets)
    if not ok:
        return msg

    # Exactly one face is present here, so take the first orientation triple.
    pitch, roll, yaw = centerface.get_orientation(dets, landmarks)[0]
    if any(np.abs(angle) > angle_thresh for angle in (pitch, roll, yaw)):
        return {"code":204,"msg":"请上传正面人脸","data":[]}

    features = recognition_model.get_features(frame, dets, landmarks)
    return {"code":200,"msg":"OK","data":features[0].tolist()}
@app.route('/video', methods=['GET', 'POST'])
def video():
    """Serve the video page for both GET and POST requests."""
    page = render_template('video.html')
    return page
@app.route('/recognition', methods=['GET', 'POST'])
def recognition():
    """Identify the single frontal face in an uploaded image.

    GET renders the upload page. POST expects the image in the ``file``
    form field and compares its features with every entry of ``face_dict``.

    Returns:
        A JSON-able dict: code 200 with the registered name in ``data`` on a
        match, 202 for an unreadable image, no face or an unregistered user,
        203 for several faces, 204 for a non-frontal face.
    """
    if request.method != 'POST':
        return render_template('upload.html', url=request.url_rule)

    upload = request.files['file']
    # Decode straight from memory. Saving to (and then deleting) a path named
    # by the client let a crafted filename overwrite and remove arbitrary
    # files in the working directory.
    raw = np.frombuffer(upload.read(), dtype=np.uint8)
    # imdecode rejects an empty buffer, so treat that as an unreadable image.
    frame = cv2.imdecode(raw, cv2.IMREAD_COLOR) if raw.size else None
    if frame is None:
        return {"code":202,"msg":"图像错误,请上传jpg/png格式","data":{}}

    dets, landmarks = centerface(frame)
    ok, msg = check_face_num(dets)
    if not ok:
        return msg

    # Exactly one face is present here, so take the first orientation triple.
    pitch, roll, yaw = centerface.get_orientation(dets, landmarks)[0]
    if any(np.abs(angle) > angle_thresh for angle in (pitch, roll, yaw)):
        return {"code":204,"msg":"请上传正面人脸","data":[]}

    # Snapshot names and features together so an index returned by compare()
    # refers to the same ordering even if face_dict changes meanwhile.
    registered = list(face_dict.items())
    if not registered:
        # np.concatenate raises on an empty sequence; with nobody registered
        # the caller is by definition unknown.
        return {"code":202,"msg":"用户未注册","data":{}}

    features = recognition_model.get_features(frame, dets, landmarks)
    # Assumes each stored entry contributes one row, so row index == entry
    # index -- TODO confirm against InsightFace.get_features.
    all_features = np.concatenate([stored for _, stored in registered])
    idx = recognition_model.compare(features[0], all_features)
    if idx > -1:
        return {"code":200,"msg":"OK","data":registered[idx][0]}
    return {"code":202,"msg":"用户未注册","data":{}}
def create_table():
    """Create the ``emotions`` table when it is missing and commit."""
    # One row per face sample: an auto-incremented id, the insertion time,
    # and one real-valued column per emotion.
    schema = '''create table if not exists emotions(
id integer PRIMARY KEY autoincrement,
recorddate datetime default (datetime()),
anger real,
contempt real,
disgust real,
fear real,
happiness real,
neutral real,
sadness real,
surprise real);'''
    cursor = database_conn.cursor()
    cursor.execute(schema)
    database_conn.commit()
def save_to_database(em_scores):
    """Insert one ``emotions`` row per score vector and commit.

    Args:
        em_scores: iterable of per-face score sequences ordered as anger,
            contempt, disgust, fear, happiness, neutral, sadness, surprise.
            Only the first eight values of each sequence are stored.
    """
    # Bind values as parameters instead of formatting them into the SQL text.
    # Entries are converted to plain floats because sqlite3 cannot bind every
    # numeric type (e.g. numpy.float32). The [:8] slice keeps the previous
    # tolerance for longer vectors: str.format ignored surplus arguments.
    rows = [tuple(float(score) for score in em[:8]) for em in em_scores]
    c = database_conn.cursor()
    c.executemany(
        "insert into emotions "
        "(anger, contempt, disgust, fear, happiness, neutral, sadness, surprise) "
        "values (?, ?, ?, ?, ?, ?, ?, ?);",
        rows,
    )
    database_conn.commit()
@app.route('/clean_emotion', methods=['GET', 'POST'])
def clean_emotion():
    """Delete all stored emotion records by dropping and recreating the table.

    NOTE(review): this destructive endpoint also answers GET and is not
    protected by login_required -- confirm that this is intended.

    Returns:
        ``{"code": 200, "msg": "OK"}`` once the empty table exists again.
    """
    c = database_conn.cursor()
    # "if exists" keeps the endpoint usable when the table was never created
    # (create_table() is otherwise only called from the __main__ entry point).
    c.execute("drop table if exists emotions")
    database_conn.commit()
    create_table()
    return {"code":200, "msg": "OK"}
@app.route('/emotion', methods=['GET', 'POST'])
def get_emotion():
    """Classify the emotion of the single face in an uploaded image.

    GET renders the upload page. POST expects the image in the ``file``
    form field.

    Returns:
        A JSON-able dict: code 200 with ``emotion`` and ``scores`` in
        ``data`` on success, 202 for an unreadable image or no face, 203
        for several faces.
    """
    if request.method != 'POST':
        return render_template('upload.html', url=request.url_rule)

    upload = request.files['file']
    # Decode straight from memory. Saving to (and then deleting) a path named
    # by the client let a crafted filename overwrite and remove arbitrary
    # files in the working directory.
    raw = np.frombuffer(upload.read(), dtype=np.uint8)
    # imdecode rejects an empty buffer, so treat that as an unreadable image.
    frame = cv2.imdecode(raw, cv2.IMREAD_COLOR) if raw.size else None
    if frame is None:
        return {"code":202,"msg":"图像错误,请上传jpg/png格式","data":{}}

    dets, _ = centerface(frame)
    ok, msg = check_face_num(dets)
    if not ok:
        return msg

    emotion, scores = emotion_model.get_emotions(frame, dets)
    data = {"emotion": emotion, "scores": scores.tolist()}
    return {"code":200,"msg":"OK","data":data}
@app.route('/')
def index():
    """Render the login page with a fresh, unbound login form."""
    login_form = LoginForm()
    return render_template(
        'login.html',
        formid='loginForm',
        action='/login',
        method='post',
        form=login_form,
    )
class LoginForm(FlaskForm):
    """Login form: required user name and password plus an optional flag."""

    # Required, 1 to 30 characters.
    username = StringField(
        '账户名:',
        validators=[DataRequired(), Length(min=1, max=30)],
    )
    # Required, 1 to 64 characters.
    password = PasswordField(
        '密码:',
        validators=[DataRequired(), Length(min=1, max=64)],
    )
    remember_me = BooleanField('记住密码', validators=[Optional()])
@login_manager.user_loader
def load_user(userid):
    """Flask-Login user loader: look the user up by its id."""
    # NOTE(review): ``User`` is neither defined nor imported in the visible
    # part of this file -- confirm where it is expected to come from.
    user = User.get(userid)
    return user
@app.route('/login',methods=['GET','POST'])
def login():
    """Handle the login form.

    On a valid submission the user is looked up by name and logged in, then
    redirected to ``/upload``. In every other case (GET, invalid form,
    unknown user) the login page is rendered again.
    """
    form = LoginForm()
    if form.validate_on_submit():
        username = form.username.data
        # NOTE(review): form.password.data is never verified, so any password
        # is accepted for an existing user. The check needs the User API,
        # which is not visible in this file.
        model = User.get(username)
        # An unknown user must not reach login_user(); the previous hard-coded
        # ``result = True`` passed whatever User.get() returned straight on.
        if model is not None:
            login_user(model)
            print()
            print('登陆成功')
            print(current_user.username)
            # is_authenticated is a property in Flask-Login >= 0.3; calling it
            # raised "TypeError: 'bool' object is not callable".
            print(current_user.is_authenticated)
            # NOTE(review): no '/upload' route is defined in the visible part
            # of this file -- confirm the redirect target.
            return redirect('/upload')
        print('登陆失败')
    return render_template('login.html',formid='loginForm',action='/login',method='post',form=form)
@sockets.route('/echo')
def echo_socket(ws):
    """Echo every received message back, prefixed with ``hello,``."""
    while not ws.closed:
        incoming = ws.receive()
        if incoming is None:
            # Nothing to answer; re-check whether the socket is still open.
            continue
        # Push the reply to the client.
        ws.send("hello," + incoming)
class EmotionThread(threading.Thread):
    """Worker that runs emotion inference on queued frames and streams results.

    Items taken from the queue are ``(timestamp, frame, dets, lms)`` tuples.
    For every item the JPEG-encoded frame and a JSON message holding the
    boxes and emotions of frontal faces are sent over the websocket.

    The producer reads ``self.queue``; it is set to ``None`` once the
    websocket is closed to signal that no more frames should be queued.
    """

    def __init__(self, msg_queue, ws, save_emotion=False):
        """
        Args:
            msg_queue: queue the producer fills with frame tuples.
            ws: websocket the results are sent to.
            save_emotion: when true, store every score vector in the database.
        """
        super().__init__()
        # run() blocks in queue.get() for as long as no frame arrives, so the
        # thread must not keep the interpreter alive after the producer stops.
        self.daemon = True
        self.queue = msg_queue
        self.save_emotion = save_emotion
        self.ws = ws

    @staticmethod
    def _frontal_only(boxes, emotions, angles):
        """Keep the detections whose pitch, roll and yaw are within angle_thresh."""
        kept_boxes, kept_emotions = [], []
        for box, emo, (pitch, roll, yaw) in zip(boxes, emotions, angles):
            if any(np.abs(angle) > angle_thresh for angle in (pitch, roll, yaw)):
                continue
            kept_boxes.append(box)
            kept_emotions.append(emo)
        return kept_boxes, kept_emotions

    def _release_queue(self):
        """Signal shutdown to the producer and unblock it if it waits in put()."""
        pending, self.queue = self.queue, None
        if pending is None:
            return
        try:
            # Free every slot; the old single get(False) raised an uncaught
            # queue.Empty whenever the queue happened to be empty.
            while True:
                pending.get(False)
        except queue.Empty:
            pass

    def run(self):
        """Process queued frames until the websocket is closed."""
        while True:
            timestamp, frame, dets, lms = self.queue.get()
            emotions, emo_score = emotion_model.get_emotions(frame, dets)
            boxes = dets[:, :4].astype('int32').tolist()
            assert len(boxes) == len(emotions)
            if self.save_emotion:
                # Scores of all faces are stored, including side faces.
                save_to_database(emo_score)

            # Filter side face.
            angles = centerface.get_orientation(dets, lms)
            boxes, emotions = self._frontal_only(boxes, emotions, angles)

            ret, frame = cv2.imencode('.jpg', frame)
            data = {
                "detection": { "boxes": boxes, "emotions": emotions},
                "timestamp": timestamp
            }
            if self.ws.closed:
                print("socket closed")
                self._release_queue()
                break
            try:
                # Image first, then the matching detection message.
                self.ws.send(frame)
                self.ws.send(json.dumps(data))
            except WebSocketError:
                print("socket closed")
                self._release_queue()
                break
def _keep_frontal_faces(boxes, emotions, angles):
    """Keep the detections whose pitch, roll and yaw are within angle_thresh."""
    kept_boxes, kept_emotions = [], []
    for box, emo, (pitch, roll, yaw) in zip(boxes, emotions, angles):
        if any(np.abs(angle) > angle_thresh for angle in (pitch, roll, yaw)):
            continue
        kept_boxes.append(box)
        kept_emotions.append(emo)
    return kept_boxes, kept_emotions


def show_video(websocket, skip_frame=False, save_emotion=False, use_thread=False):
    """Stream camera frames and emotion detections over a websocket.

    Every frame is sent JPEG-encoded; frames that went through detection are
    followed by a JSON message with the boxes and emotions of frontal faces
    and the capture timestamp.

    Args:
        websocket: socket the frames and detection messages are sent to.
        skip_frame: run detection on every second frame only; the other
            frames are sent without a detection message.
        save_emotion: store every emotion score vector in the database.
        use_thread: hand detected frames to an EmotionThread, which then does
            the emotion inference and the sending. Not combinable with
            ``skip_frame``.

    Raises:
        ValueError: the video source cannot be opened or a frame cannot be
            read.
    """
    print("start to show video..")
    # Checked before the device is opened so a bad call cannot leak a capture.
    assert not (skip_frame and use_thread), "Do not use thread when using skip frame."

    # Capture source. A dead assignment of an RTSP URL with embedded
    # credentials, overwritten by this value, was removed.
    video_path = '/dev/video0'
    cap = cv2.VideoCapture(video_path)
    try:
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)
        fps = cap.get(cv2.CAP_PROP_FPS)
        print("FPS", fps)
        if not cap.isOpened():
            raise ValueError("video open failed")

        emotion_thread = None
        if use_thread:
            # Small bound so the producer cannot run far ahead of inference.
            msg_queue = queue.Queue(2)
            emotion_thread = EmotionThread(msg_queue, websocket, save_emotion)
            emotion_thread.start()

        frame_id = 0
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                raise ValueError("video open failed")
            timestamp = time.time()

            skip = skip_frame and frame_id % 2 != 0
            # Skipped frames carry no detections. The original assigned a
            # misspelled ``emo`` here and left ``emotions`` stale.
            boxes, emotions = [], []
            if not skip:
                dets, lms = centerface(frame, threshold=0.35)
                if use_thread:
                    # The worker sets its queue to None once the socket closed.
                    msg_queue = emotion_thread.queue
                    if msg_queue is None:
                        break
                    msg_queue.put((timestamp, frame, dets, lms))
                    frame_id += 1
                    continue
                emotions, emo_scores = emotion_model.get_emotions(frame, dets)
                boxes = dets[:, :4].astype('int32').tolist()
                assert len(boxes) == len(emotions)
                print("num:", len(boxes))
                if save_emotion:
                    # Scores of all faces are stored, including side faces.
                    save_to_database(emo_scores)
                # Filter side face.
                angles = centerface.get_orientation(dets, lms)
                boxes, emotions = _keep_frontal_faces(boxes, emotions, angles)

            ret, frame = cv2.imencode('.jpg', frame)
            data = {
                "detection": { "boxes": boxes, "emotions": emotions},
                "timestamp": timestamp
            }
            frame_id += 1
            if websocket.closed:
                break
            try:
                websocket.send(frame)
                if not skip:
                    websocket.send(json.dumps(data))
            except WebSocketError:
                print("socket closed")
                break
    finally:
        # Release the device on every exit path, including the raises above.
        cap.release()
@sockets.route('/camera', websocket=True)
def camera_socket(ws):
    """Start streaming video to the client after each message it sends."""
    while not ws.closed:
        if ws.receive() is None:
            continue
        # Any non-empty message acts as the start signal.
        show_video(ws)


# Register the same endpoint once more with an explicit websocket rule.
# NOTE(review): presumably a workaround for the route decorator not marking
# the rule as a websocket rule -- confirm before removing.
sockets.url_map.add(Rule('/camera', endpoint=camera_socket, websocket=True))
def start_flask(port):
    """Serve ``app`` on all interfaces at *port* with websocket support.

    Blocks until the server stops.
    """
    # Imported here so gevent is only loaded when the server is started.
    from gevent import pywsgi
    from geventwebsocket.handler import WebSocketHandler

    listen_address = ('0.0.0.0', port)
    server = pywsgi.WSGIServer(
        listen_address,
        app,
        handler_class=WebSocketHandler,
    )
    print('server start')
    server.serve_forever()
def parse_args():
    """Parse the command line arguments.

    Returns:
        argparse.Namespace with ``port`` (int, default 8000), the TCP port
        the server listens on.
    """
    # The previous description ("export to onnx") did not describe this program.
    parser = argparse.ArgumentParser(
        description="Face registration, recognition and emotion analysis server")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()
    return args
if __name__ == '__main__':
    # Script entry point: read the options, make sure the emotions table
    # exists, then serve until the process is stopped.
    args = parse_args()
    create_table()
    start_flask(port=args.port)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化