# NOTE(review): scraping artifact from the Gitee page ("code pull complete, page will auto-refresh") — not part of the program; kept as a comment so the file parses.
from utils.my_tools import tokenizer
from utils.my_tools import generate
from utils.my_tools import output_path__ppo_model, device, output_path__gen_model
from utils.my_tools import TestModel
from utils.time_tools import with_timer
from bdtime import tt
import torch
# Select which generator model to evaluate:
#   use_ppo = 1 -> take the generator wrapped inside a trained PPO model;
#   use_ppo = 0 -> load the standalone pretrained generator checkpoint.
use_ppo = 1
if use_ppo:
    # NOTE(review): torch.load of a full model object uses pickle — only load
    # trusted checkpoints; consider map_location=device for CPU-only hosts.
    model_ppo = torch.load(output_path__ppo_model)
    # model_ppo = model_ppo.to(device)
    # model_ppo.eval()
    model_gen = model_ppo.model_gen
else:
    model_gen = torch.load(output_path__gen_model)
# 随机一批数据
# _, input_ids, _ = tokenizer.get_batch_data(prefix=True)
#
# #切分成question和answer
# split = [i.index(tokenizer.encoder['=']) + 1 for i in input_ids]
# question = [input_ids[i][:split[i]] for i in range(len(input_ids))]
# answer = [input_ids[i][split[i]:] for i in range(len(input_ids))]
# Sample a batch of questions together with their ground-truth answers,
# then let the generator produce a prediction for each question.
batch_size = 100
label, _question, attention_mask, real_answer = TestModel.get_question(
    prefix=True, batch_size=batch_size, ret_real_answer=True)

with with_timer(f"生成predict, batch_size: {batch_size}", tt) as wt:
    # One forward pass per question: shape each token list into a (1, seq) tensor.
    input_ids = [torch.LongTensor(i).unsqueeze(0).to(device) for i in _question]
    _predict_qa = [generate(model_gen, i) for i in input_ids]
    # Keep only the generated continuation — strip the echoed question prefix.
    _predict = [p[0].tolist()[len(q):] for p, q in zip(_predict_qa, _question)]
# tt.sleep(3)
# 解码成文本
# question = [tokenizer.decode(i) for i in question]
# answer = [tokenizer.decode(i) for i in real_answer]
# predict = [tokenizer.decode(i) for i in predict]
# from utils.my_tools import show_qap
# show_qap(_question, real_answer, _predict, end=10, skip_spacial_symbols=True)
# Optionally score (question + real answer) pairs with the trained classifier.
flag__test_cls = 1
if flag__test_cls:
    with with_timer(f"test_cls, batch_size: {batch_size}", tt) as wt:
        from utils.my_tools import output_path__cls_model, test_predict_cls
        # Concatenate each question with its real answer, then pad to a batch.
        qa_token = [torch.cat((q, a)) for q, a in zip(_question, real_answer)]
        qa_input_ids, qa_attention_mask = tokenizer.batch_pad(token=qa_token)
        qa_input_ids = torch.LongTensor(qa_input_ids).to(device)
        qa_attention_mask = torch.LongTensor(qa_attention_mask).to(device)
        # NOTE(review): pickle-based torch.load — trusted checkpoints only.
        model_cls = torch.load(output_path__cls_model)
        test_predict_cls(model_cls, qa_input_ids, qa_attention_mask, label)
# Decode token ids back to text and report exact-match accuracy of the
# generated answers against the ground truth.
question = tokenizer.decode(_question, '', skip_spacial_symbols=True)
answer = tokenizer.decode(real_answer, '', skip_spacial_symbols=True)
predict = tokenizer.decode(_predict, '', skip_spacial_symbols=True)

with with_timer(f"计算accuracy, batch_size: {batch_size}", tt) as wt:
    acc = 0
    show_times = 5  # print the first few (q, a, p) triples for inspection
    # enumerate/zip replaces the original zip(list(range(len(...))), ...) form;
    # the unused per-iteration `from utils.my_tools import test_model_cls` was removed.
    for i, (q, a, p) in enumerate(zip(question, answer, predict)):
        if i < show_times:
            print(a == p, '--- q, a, p ---', q, a, p)
        if a == p:
            acc += 1
    print('--- accuracy:', round(acc / len(question), 3), f"total_types: {tokenizer.total_types}")
# NOTE(review): the two lines below are Gitee content-moderation boilerplate
# captured when this file was scraped — not part of the program; kept as
# comments so the file parses.
# 此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
# 如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。