文件
Clone or Download
manage.py 6.26 KB
Copy Edit Raw Blame History
十指紧扣 authored 25 days ago . update
# -*- coding:utf8 -*-
"""
Created on 2019/9/26 17:18
@author: minc
"""
from flask import Flask, render_template
from flask.json import jsonify
from pyecharts import options as opts
from pyecharts.charts import Bar
from pyecharts.charts import Line
from random import randrange
import datetime,os
from flask import request
import time
import json
from sklearn import metrics
import numpy as np
import os
from utils import get_option
# Absolute directory holding this script, and its static-assets subfolder.
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
STATIC_DIR = os.path.join(BASE_DIR, 'static')
app = Flask(__name__)
# Input locations come from the environment: DATA_FILE is a JSON list of
# record objects, LABEL_FILE a newline-separated label vocabulary. A missing
# variable raises KeyError when the module is imported (DATA_FILE checked first).
data_file, label_file = (os.environ[name] for name in ('DATA_FILE', 'LABEL_FILE'))
def _index_report(data):
    """Return per-label precision/recall/F1 rows for the index page.

    ``data`` is the list built by ``show_demo01``. Each item carries a
    ``'#'``-joined gold ``label`` string and a ``'#'``-joined ``predict``
    string. The label set is the union of the non-empty gold labels, sorted
    so rows keep a stable order between runs. Predicted labels outside that
    set are ignored.

    Each row is a dict with string values ``name``, ``precision``,
    ``recall``, ``f1_score`` and ``support``. There is one row per label,
    followed by the average rows emitted by ``classification_report``.
    Returns an empty list when no record has any gold label.
    """
    # Empty names come from records with no labels ('' .split('#') -> ['']);
    # they must not become a class of their own.
    gold_lists = [[name for name in item['label'].split('#') if name] for item in data]
    names = sorted(set().union(*gold_lists))
    if not names:
        return []
    column = {name: j for j, name in enumerate(names)}
    y = np.zeros((len(data), len(names)))
    y_hat = np.zeros((len(data), len(names)))
    for i, (item, gold) in enumerate(zip(data, gold_lists)):
        for name in gold:
            y[i, column[name]] = 1
        for name in item['predict'].split('#'):
            if name in column:
                y_hat[i, column[name]] = 1
    # output_dict avoids re-parsing the text report. Whitespace splitting
    # mangled the two-word "micro avg"-style rows and names with spaces.
    scores = metrics.classification_report(y, y_hat, target_names=names, output_dict=True)
    report = []
    for name, row in scores.items():
        # Scalar entries (e.g. "accuracy") carry no per-class breakdown.
        if not isinstance(row, dict):
            continue
        report.append({
            'name': name,
            'precision': format(row['precision'], '.4f'),
            'recall': format(row['recall'], '.4f'),
            'f1_score': format(row['f1-score'], '.4f'),
            'support': str(int(row['support'])),
        })
    return report


# Index page: renders templates/index.html with all records and a metrics report.
@app.route("/")
def show_demo01():
    """Render ``index.html`` with every record and a per-label metrics report.

    Reads the JSON records from ``data_file`` and the label vocabulary (one
    label per line) from ``label_file``. Each record's ``labels`` list is
    reduced to the vocabulary labels it contains, in vocabulary order.

    Template context:
        data: list of dicts with keys ``id`` (record ``id`` or its position),
            ``text`` (all string fields concatenated), ``predict`` and
            ``label``. The last two are ``'#'``-joined. ``label`` uses the
            record's ``label`` field when present, else its filtered ``labels``.
        report: rows produced by ``_index_report`` for ``data``.
    """
    with open(data_file, 'r', encoding='utf-8') as f:
        origin_data = json.load(f)
    with open(label_file, 'r', encoding='utf-8') as f:
        labels = f.read().strip().split('\n')
    for origin_item in origin_data:
        origin_item['labels'] = [label for label in labels if label in origin_item['labels']]

    data = []
    for i, origin_item in enumerate(origin_data):
        text = ''.join(value for value in origin_item.values() if isinstance(value, str))
        gold = origin_item['label'] if 'label' in origin_item else origin_item['labels']
        data.append({
            'id': origin_item.get('id', i),
            'text': text,
            'predict': '#'.join(origin_item['predict']),
            'label': '#'.join(gold),
        })
    return render_template("index.html", data=data, report=_index_report(data))
@app.route("/item", methods=['GET'])
def get_emr_info():
    """Render ``item.html`` for the single record selected by ``?id=``.

    A record matches when its ``id`` field equals the integer query value.
    A record without an ``id`` field matches when its position in the data
    file equals that value. The first match wins.

    The matched record's ``labels`` are reduced to the vocabulary labels from
    ``label_file``, in vocabulary order. Every field is then passed to the
    template as a ``{"key": ..., "value": ...}`` entry. ``graph_option`` is
    built by ``get_option`` from the record's ``nodes`` and ``paths``, or from
    empty lists when the record has no ``nodes``.

    Returns a 400 response when ``id`` is missing or not an integer, and a
    404 response when no record matches.
    """
    try:
        wanted_id = int(request.args.get('id'))
    except (TypeError, ValueError):
        return "query parameter 'id' must be an integer", 400

    with open(data_file, 'r', encoding='utf-8') as f:
        origin_data = json.load(f)
    with open(label_file, 'r', encoding='utf-8') as f:
        labels = f.read().strip().split('\n')

    # Records without an explicit id are addressed by their position.
    origin_item = next(
        (item for i, item in enumerate(origin_data) if item.get('id', i) == wanted_id),
        None,
    )
    if origin_item is None:
        return "no record with id {}".format(wanted_id), 404
    origin_item['labels'] = [label for label in labels if label in origin_item['labels']]

    emr = [{"key": key, "value": value} for key, value in origin_item.items()]
    if 'nodes' in origin_item:
        # Cut each path into 3-item windows starting at every even index
        # (presumably head, relation, tail); shorter tails are dropped.
        triples = [
            path[j:j + 3]
            for path in origin_item['paths']
            for j in range(0, len(path) - 2, 2)
        ]
        graph_option = get_option(origin_item['nodes'], triples)
    else:
        graph_option = get_option([], [])
    return render_template("item.html", emr=emr, graph_option=graph_option)
@app.route('/search', methods=['GET'])
def search():
    """Filter records by the ``##``-separated ``key:value`` terms in ``?query=``.

    Recognised terms (whitespace around key and value is ignored):
        id:<value>       the record id, as text, equals value
        predict:<value>  value is a substring of the '#'-joined predictions
        label:<value>    value is a substring of the '#'-joined labels

    Terms with other keys or an empty value are ignored. When a key repeats,
    the last occurrence wins. Rows have the same shape as the index table:
    ``id``, ``text``, ``predict`` and ``label`` (the record's ``label`` field
    when present, else its vocabulary-filtered ``labels``).

    Returns JSON ``{"msg", "code", "data"}`` with the matching rows in ``data``.
    """
    filters = {}
    for term in (request.args.get('query') or '').split('##'):
        # Test the key only; the old substring test on the whole term let a
        # value such as "label:idiopathic" also set the id filter.
        key, _, value = term.partition(':')
        key, value = key.strip(), value.strip()
        if key in ('id', 'predict', 'label') and value:
            filters[key] = value
    query_id = filters.get('id')
    query_predict = filters.get('predict')
    query_label = filters.get('label')

    with open(data_file, 'r', encoding='utf-8') as f:
        origin_data = json.load(f)
    with open(label_file, 'r', encoding='utf-8') as f:
        labels = f.read().strip().split('\n')
    for origin_item in origin_data:
        origin_item['labels'] = [label for label in labels if label in origin_item['labels']]

    data = []
    for i, origin_item in enumerate(origin_data):
        text = ''.join(value for value in origin_item.values() if isinstance(value, str))
        gold = origin_item['label'] if 'label' in origin_item else origin_item['labels']
        item = {
            'id': origin_item.get('id', i),
            'text': text,
            'predict': '#'.join(origin_item['predict']),
            'label': '#'.join(gold),
        }
        # The query value is a string while ids may be ints (positions are),
        # so compare them as text.
        if query_id is not None and str(item['id']) != query_id:
            continue
        if query_predict is not None and query_predict not in item['predict']:
            continue
        if query_label is not None and query_label not in item['label']:
            continue
        data.append(item)
    return jsonify({"msg": "提交成功", "code": 200, "data": data})
if __name__ == "__main__":
    # Development entry point: listen on every interface, port 8004, debugger off.
    app.run(host='0.0.0.0', port=8004, debug=False)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化