代码拉取完成,页面将自动刷新
同步操作将从 我没得冰阔落/nlp_machine_translation 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import torch
import numpy as np
import time
import parsee
from torch.autograd import Variable
from utils import subsequent_mask
def log(data, timestamp):
file = open(f'log/log-{timestamp}.txt', 'a')
file.write(data)
file.write('\n')
file.close()
def greedy_decode(model, src, src_mask, max_len, start_symbol):
memory = model.encode(src, src_mask)
ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
for i in range(max_len-1):
out = model.decode(memory, src_mask,
Variable(ys),
Variable(subsequent_mask(ys.size(1))
.type_as(src.data)))
prob = model.generator(out[:, -1])
_, next_word = torch.max(prob, dim = 1)
next_word = next_word.data[0]
ys = torch.cat([ys,
torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
return ys
def evaluate(data , model, presrc = None):
timestamp = time.time()
if presrc:
src = torch.from_numpy(np.array(presrc)).long().to("cuda")
src = src.unsqueeze(0)
src_mask = (src != 0).unsqueeze(-2)
out = greedy_decode(model, src, src_mask, max_len = parsee.max_length, start_symbol = data.cn_word_dict["BOS"])
translation = []
for j in range(1, out.size(1)):
sym = data.cn_index_dict[out[0, j].item()]
if sym != 'EOS':
translation.append(sym)
else:
break
# print("translation: %s" % " ".join(translation))
else:
with torch.no_grad():
for i in range(len(data.dev_en)):
src = torch.from_numpy(np.array(data.dev_en[i])).long().to("cuda")
src = src.unsqueeze(0)
src_mask = (src != 0).unsqueeze(-2)
out = greedy_decode(model, src, src_mask, max_len = parsee.max_length, start_symbol = data.cn_word_dict["BOS"])
translation = []
for j in range(1, out.size(1)):
sym = data.cn_index_dict[out[0, j].item()]
if sym != 'EOS':
translation.append(sym)
else:
break
# print("translation: %s" % " ".join(translation))
return translation
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。