加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
evaluate.py 2.30 KB
一键复制 编辑 原始数据 按行查看 历史
我没得冰阔落 提交于 2023-07-04 09:13 . v1
import torch
import numpy as np
import time
import parsee
from torch.autograd import Variable
from utils import subsequent_mask
def log(data, timestamp):
    """Append one line of text to the per-run log file.

    The line is written to ``log/log-<timestamp>.txt`` relative to the
    current working directory, followed by a newline.

    Args:
        data: Text to append; must be a ``str`` because it is passed
            straight to ``file.write``.
        timestamp: Value interpolated into the file name to identify the run.
    """
    # Function-scope import keeps this block self-contained.
    import os

    # Create the directory on first use so a fresh checkout does not crash
    # with FileNotFoundError.
    os.makedirs('log', exist_ok=True)

    # The context manager guarantees the handle is closed even if a write
    # raises part-way through.
    with open(f'log/log-{timestamp}.txt', 'a') as log_file:
        log_file.write(data)
        log_file.write('\n')
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Decode a single sequence by repeatedly picking the top-scoring token.

    The source is encoded once; the target sequence then grows one token at
    a time, each step taking the argmax of ``model.generator`` over the last
    decoder position. Decoding always runs the full ``max_len - 1`` steps;
    there is no early stop on an end-of-sentence token, so callers must
    truncate the result themselves.

    Args:
        model: Object exposing ``encode(src, src_mask)``,
            ``decode(memory, src_mask, tgt, tgt_mask)`` and
            ``generator(hidden)``.
        src: Source tensor. The target buffer is created with shape
            ``(1, 1)``, so a batch of exactly one sentence is assumed.
        src_mask: Mask forwarded unchanged to ``encode`` and ``decode``.
        max_len: Maximum length of the returned sequence, including the
            start symbol.
        start_symbol: Index of the token that seeds the target sequence.

    Returns:
        Tensor of shape ``(1, max(max_len, 1))`` with the same type as
        ``src``; column 0 holds ``start_symbol``.
    """
    # The output is assembled purely from argmax indices, which carry no
    # gradient, so building an autograd graph here only wastes memory.
    with torch.no_grad():
        memory = model.encode(src, src_mask)
        ys = torch.ones(1, 1).fill_(start_symbol).type_as(src)
        for _ in range(max_len - 1):
            # Mask sized to the tokens generated so far.
            tgt_mask = subsequent_mask(ys.size(1)).type_as(src)
            out = model.decode(memory, src_mask, ys, tgt_mask)
            # Only the newest position is needed to choose the next token.
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_token = torch.ones(1, 1).type_as(src).fill_(next_word[0])
            ys = torch.cat([ys, next_token], dim=1)
    return ys
def _translate_sentence(data, model, token_ids, device):
    """Greedy-decode one index-encoded source sentence into target tokens."""
    src = torch.from_numpy(np.array(token_ids)).long().to(device)
    src = src.unsqueeze(0)  # add the batch dimension expected by the decoder
    # Positions equal to 0 are masked out (presumably the padding index).
    src_mask = (src != 0).unsqueeze(-2)
    out = greedy_decode(
        model,
        src,
        src_mask,
        max_len=parsee.max_length,
        start_symbol=data.cn_word_dict["BOS"],
    )

    translation = []
    # Column 0 is the start symbol, so output tokens begin at column 1.
    for j in range(1, out.size(1)):
        sym = data.cn_index_dict[out[0, j].item()]
        if sym == 'EOS':
            break
        translation.append(sym)
    return translation


def evaluate(data, model, presrc=None, device="cuda"):
    """Translate a given sentence, or run the model over the dev set.

    Args:
        data: Object providing ``cn_word_dict`` (token -> index),
            ``cn_index_dict`` (index -> token) and ``dev_en`` (a sequence of
            index-encoded source sentences).
        model: Model accepted by ``greedy_decode``.
        presrc: Optional index-encoded source sentence. When truthy, only
            this sentence is translated. ``None`` or an empty sequence
            selects the dev-set path.
        device: Device the source tensor is moved to. Defaults to ``"cuda"``.

    Returns:
        List of target-side tokens up to, but not including, ``'EOS'``.
        On the dev-set path every sentence is decoded but only the last
        sentence's translation is returned. An empty dev set yields ``[]``.
    """
    # Inference only: disable autograd on both paths.
    with torch.no_grad():
        if presrc:
            return _translate_sentence(data, model, presrc, device)

        # Start from an empty result so an empty dev set does not raise
        # UnboundLocalError at the return below.
        translation = []
        for sentence in data.dev_en:
            translation = _translate_sentence(data, model, sentence, device)
    return translation
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化