代码拉取完成,页面将自动刷新
同步操作将从 我没得冰阔落/nlp_machine_translation 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import copy
import os
import sys
import parsee
import torch
import torch.nn as nn
import torch.nn.functional as F
from prepare_data import PrepareData
from model.attention import MultiHeadedAttention
from model.position_wise_feedforward import PositionwiseFeedForward
from model.embedding import PositionalEncoding, Embeddings
from model.transformer import Transformer
from model.encoder import Encoder, EncoderLayer
from model.decoder import Decoder, DecoderLayer
from model.generator import Generator
from lib.criterion import LabelSmoothing
from lib.optimizer import NoamOpt
from train import train
from evaluate import evaluate
from auto_check import *
from tools import evaluate_cn
def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """Assemble an encoder-decoder Transformer on ``parsee.device``.

    Args:
        src_vocab: source vocabulary size, passed to the source-side Embeddings.
        tgt_vocab: target vocabulary size, passed to the target-side Embeddings
            and to the Generator.
        N: number of stacked layers handed to both Encoder and Decoder.
        d_model: model width shared by every sub-module.
        d_ff: inner width of the position-wise feed-forward network.
        h: number of attention heads.
        dropout: dropout rate forwarded to the sub-modules that take one.

    Returns:
        The Transformer, moved to ``parsee.device``; every parameter with more
        than one dimension is re-initialised with Xavier-uniform.
    """
    device = parsee.device
    clone = copy.deepcopy

    # Prototype sub-layers: each layer below receives its own deep copy,
    # so no weights are shared between layers.
    attn_proto = MultiHeadedAttention(h, d_model).to(device)
    ffn_proto = PositionwiseFeedForward(d_model, d_ff, dropout).to(device)
    pos_proto = PositionalEncoding(d_model, dropout).to(device)

    # Construction order (encoder, decoder, src embed, tgt embed, generator)
    # matches the original nested call so constructor-time RNG draws are identical.
    enc_layer = EncoderLayer(d_model, clone(attn_proto), clone(ffn_proto), dropout).to(device)
    encoder = Encoder(enc_layer, N).to(device)

    dec_layer = DecoderLayer(
        d_model, clone(attn_proto), clone(attn_proto), clone(ffn_proto), dropout
    ).to(device)
    decoder = Decoder(dec_layer, N).to(device)

    src_embed = nn.Sequential(Embeddings(d_model, src_vocab).to(device), clone(pos_proto))
    tgt_embed = nn.Sequential(Embeddings(d_model, tgt_vocab).to(device), clone(pos_proto))
    generator = Generator(d_model, tgt_vocab)

    model = Transformer(encoder, decoder, src_embed, tgt_embed, generator).to(device)

    # Glorot / fan_avg initialisation (carried over from the reference
    # implementation) for every parameter with more than one dimension;
    # 1-D parameters keep whatever their constructors assigned.
    matrices = (param for param in model.parameters() if param.dim() > 1)
    for weight in matrices:
        nn.init.xavier_uniform_(weight)
    return model.to(device)
def _tokens_to_ids(tokens, vocab):
    """Map each token in ``tokens`` to its id in the ``vocab`` mapping.

    A token missing from ``vocab`` falls back to the ``'UNK'`` entry when the
    vocabulary has one (assumed to be registered by PrepareData -- TODO
    confirm the key name). Without such an entry a KeyError naming the token
    is raised, matching the previous behaviour.
    """
    unk_id = vocab['UNK'] if 'UNK' in vocab else None
    ids = []
    for token in tokens:
        if token in vocab:
            ids.append(vocab[token])
        elif unk_id is not None:
            ids.append(unk_id)
        else:
            raise KeyError(token)
    return ids


def translate(text, lang_out):
    """Translate ``text`` using the Transformer checkpoint at ``parsee.save_file``.

    The input language is obtained from ``detect`` (star-imported from
    auto_check).

    - English input is lower-cased, whitespace-split and decoded with
      ``evaluate``.
    - Chinese input is split into characters and decoded with ``evaluate_cn``.
    - If the detected language equals ``lang_out``, or is neither 'en' nor
      'zh', ``text`` is returned unchanged.

    Args:
        text: raw input string.
        lang_out: target language code, compared against the detected one.

    Returns:
        ``(result, sentence)``. ``result`` is the translated text, or ``text``
        itself when no translation was done. ``sentence`` is the string
        returned by ``detect``, or ``''`` when it is identical to ``text``.

    Raises:
        FileNotFoundError: if no checkpoint exists at ``parsee.save_file``.
        KeyError: if a token is out of vocabulary and the vocabulary has no
            'UNK' entry.
    """
    # Fail fast: the old code printed an error and then crashed with
    # UnboundLocalError on ``result`` at the final return.
    if not os.path.exists(parsee.save_file):
        raise FileNotFoundError(
            f"No trained model found at {parsee.save_file!r}; "
            "please train before evaluating"
        )

    # Data preprocessing: vocabulary sizes determine the embedding sizes.
    data = PrepareData()
    parsee.src_vocab = len(data.en_word_dict)
    parsee.tgt_vocab = len(data.cn_word_dict)

    # Build the model with the configured hyper-parameters, then load weights.
    model = make_model(
        parsee.src_vocab,
        parsee.tgt_vocab,
        parsee.layers,
        parsee.d_model,
        parsee.d_ff,
        parsee.h_num,
        parsee.dropout,
    )
    # map_location lets a checkpoint saved on one device (e.g. GPU) load on
    # whatever device this process is configured for.
    model.load_state_dict(torch.load(parsee.save_file, map_location=parsee.device))
    # Inference only: switch off dropout in case evaluate/evaluate_cn do not.
    model.eval()

    sentence, _detail, lang = detect(text)
    if lang == lang_out:
        print("same")
        result = text
    elif lang == 'en':
        src_ids = _tokens_to_ids(sentence.lower().split(), data.en_word_dict)
        result = "".join(evaluate(data, model, src_ids))
    elif lang == 'zh':
        # Chinese is tokenised per character.
        src_ids = _tokens_to_ids(list(sentence), data.cn_word_dict)
        result = "".join(evaluate_cn(sentence, model, src_ids))
    else:
        result = text

    if sentence == text:
        sentence = ''
    return result, sentence
if __name__ == "__main__":
    # Demo run: translate a fixed Chinese sentence and echo the result
    # alongside the sentence returned by language detection.
    demo_input = "我是个男孩。"
    demo_output, detected_sentence = translate(demo_input, 'zh')
    print(demo_output, detected_sentence)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。