加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
f4_test.py 3.06 KB
一键复制 编辑 原始数据 按行查看 历史
bode135 提交于 2023-11-14 17:56 . test_ppo
from utils.my_tools import tokenizer
from utils.my_tools import generate
from utils.my_tools import output_path__ppo_model, device, output_path__gen_model
from utils.my_tools import TestModel
from utils.time_tools import with_timer
from bdtime import tt
import torch
# Choose which generator to evaluate: the PPO-finetuned model or the plain
# supervised generator. Both paths leave `model_gen` bound for the rest of
# the script.
use_ppo = True
if use_ppo:
    # map_location keeps the script runnable on CPU-only machines when the
    # checkpoint was saved from a GPU process.
    model_ppo = torch.load(output_path__ppo_model, map_location=device)
    # NOTE(review): .eval() was intentionally left disabled upstream — confirm
    # whether dropout during evaluation is deliberate.
    model_gen = model_ppo.model_gen
else:
    model_gen = torch.load(output_path__gen_model, map_location=device)
# Sample a random batch of questions, keeping the ground-truth answers so we
# can score the generations afterwards.
batch_size = 100
label, _question, attention_mask, real_answer = TestModel.get_question(
    prefix=True, batch_size=batch_size, ret_real_answer=True)

# Generate a prediction for every question, timing the whole batch.
with with_timer(f"生成predict, batch_size: {batch_size}", tt) as wt:
    # One (1, seq_len) tensor per question; generate() is called per sample.
    input_ids = [torch.LongTensor(q).unsqueeze(0).to(device) for q in _question]
    _predict_qa = [generate(model_gen, ids) for ids in input_ids]
    # The model echoes the question prefix; keep only the newly generated
    # answer tokens by slicing off len(question) leading tokens.
    _predict = [qa[0].tolist()[len(q):] for qa, q in zip(_predict_qa, _question)]
# Optionally run the trained classifier over (question + real answer) pairs
# as a sanity check of the classification head.
flag__test_cls = True
if flag__test_cls:
    with with_timer(f"test_cls, batch_size: {batch_size}", tt) as wt:
        from utils.my_tools import output_path__cls_model, test_predict_cls
        # Concatenate question and real-answer tokens into one sequence per
        # sample, then pad to a rectangular batch.
        # NOTE(review): torch.cat assumes _question/real_answer items are
        # tensors — confirm against TestModel.get_question.
        qa_token = [torch.cat((q, a)) for q, a in zip(_question, real_answer)]
        qa_input_ids, qa_attention_mask = tokenizer.batch_pad(token=qa_token)
        qa_input_ids = torch.LongTensor(qa_input_ids).to(device)
        qa_attention_mask = torch.LongTensor(qa_attention_mask).to(device)
        # map_location keeps this runnable on CPU-only machines.
        model_cls = torch.load(output_path__cls_model, map_location=device)
        test_predict_cls(model_cls, qa_input_ids, qa_attention_mask, label)
# Decode token ids back to text so exact-match accuracy can be computed.
question = tokenizer.decode(_question, '', skip_spacial_symbols=True)
answer = tokenizer.decode(real_answer, '', skip_spacial_symbols=True)
predict = tokenizer.decode(_predict, '', skip_spacial_symbols=True)

with with_timer(f"计算accuracy, batch_size: {batch_size}", tt) as wt:
    acc = 0
    show_times = 5  # print the first few samples for a quick eyeball check
    for i, (q, a, p) in enumerate(zip(question, answer, predict)):
        if i < show_times:
            print(a == p, '--- q, a, p ---', q, a, p)
        # Exact string match between real answer and generated answer.
        if a == p:
            acc += 1
    # Guard the division so an empty batch reports 0.0 instead of crashing.
    total = len(question)
    print('--- accuracy:', round(acc / total, 3) if total else 0.0,
          f"total_types: {tokenizer.total_types}")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化