代码拉取完成,页面将自动刷新
同步操作将从 我没得冰阔落/nlp_machine_translation 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import time
import torch
import parsee
from lib.loss import SimpleLossCompute
def run_epoch(data, model, loss_compute, epoch):
start = time.time()
total_tokens = 0
total_loss = 0
tokens = 0
for i , batch in enumerate(data):
out = model(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
loss = loss_compute(out, batch.trg_y, batch.ntokens)
total_loss += loss
total_tokens += batch.ntokens
tokens += batch.ntokens
if i % 50 == 1:
elapsed = time.time() - start
print("Epoch %d Batch: %d Loss: %f Tokens per Sec: %fs" % (epoch, i - 1, loss / batch.ntokens, tokens / elapsed / 1000))
start = time.time()
tokens = 0
return total_loss / total_tokens
def train(data, model, criterion, optimizer):
for epoch in range(parsee.epochs):
model.train()
run_epoch(data.train_data, model, SimpleLossCompute(model.generator, criterion, optimizer), epoch)
model.eval()
print('>>>>> Evaluate')
loss = run_epoch(data.dev_data, model, SimpleLossCompute(model.generator, criterion, None), epoch)
print('<<<<< Evaluate loss: %f' % loss)
torch.save(model.state_dict(), parsee.save_file)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。