加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
translation.py 3.48 KB
一键复制 编辑 原始数据 按行查看 历史
我没得冰阔落 提交于 2023-07-04 09:13 . v1
import copy
import os
import sys
import parsee
import torch
import torch.nn as nn
import torch.nn.functional as F
from prepare_data import PrepareData
from model.attention import MultiHeadedAttention
from model.position_wise_feedforward import PositionwiseFeedForward
from model.embedding import PositionalEncoding, Embeddings
from model.transformer import Transformer
from model.encoder import Encoder, EncoderLayer
from model.decoder import Decoder, DecoderLayer
from model.generator import Generator
from lib.criterion import LabelSmoothing
from lib.optimizer import NoamOpt
from train import train
from evaluate import evaluate
from auto_check import *
from tools import evaluate_cn
def make_model(src_vocab, tgt_vocab, N = 6, d_model = 512, d_ff = 2048, h = 8, dropout = 0.1):
    """Assemble an encoder-decoder ``Transformer`` placed on ``parsee.device``.

    Args:
        src_vocab: Source vocabulary size, passed to the source ``Embeddings``.
        tgt_vocab: Target vocabulary size, passed to the target ``Embeddings``
            and to the ``Generator``.
        N: Passed to ``Encoder`` / ``Decoder`` together with a prototype layer
            (presumably the number of stacked layers -- see those classes).
        d_model: Model (embedding) width used by every sub-module.
        d_ff: Hidden width of the position-wise feed-forward block.
        h: Number of attention heads for ``MultiHeadedAttention``.
        dropout: Dropout rate handed to the sub-modules that take one.

    Returns:
        The assembled model on ``parsee.device``.  Every parameter with more
        than one dimension is re-initialised with Xavier/Glorot uniform.
    """
    device = parsee.device
    dup = copy.deepcopy

    # Prototype sub-modules; each consumer below receives its own deep copy.
    attn = MultiHeadedAttention(h, d_model).to(device)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout).to(device)
    position = PositionalEncoding(d_model, dropout).to(device)

    # Built in the same order as the arguments of a single nested
    # Transformer(...) call would be, so default-init RNG draws are unchanged.
    enc_layer = EncoderLayer(d_model, dup(attn), dup(ff), dropout).to(device)
    encoder = Encoder(enc_layer, N).to(device)
    dec_layer = DecoderLayer(d_model, dup(attn), dup(attn), dup(ff), dropout).to(device)
    decoder = Decoder(dec_layer, N).to(device)
    src_embed = nn.Sequential(Embeddings(d_model, src_vocab).to(device), dup(position))
    tgt_embed = nn.Sequential(Embeddings(d_model, tgt_vocab).to(device), dup(position))
    generator = Generator(d_model, tgt_vocab)

    model = Transformer(encoder, decoder, src_embed, tgt_embed, generator).to(device)

    # Glorot / fan_avg init for weight matrices (dim > 1); 1-D parameters
    # such as biases keep their default initialisation.
    for weight in filter(lambda t: t.dim() > 1, model.parameters()):
        nn.init.xavier_uniform_(weight)
    return model.to(device)
def translate(text, lang_out):
    """Translate ``text`` using the Transformer checkpoint at ``parsee.save_file``.

    The vocabularies come from ``PrepareData``. As a side effect, their sizes
    are written to ``parsee.src_vocab`` / ``parsee.tgt_vocab``. The input
    language is taken from ``detect(text)``:

    - English input is lower-cased, whitespace-split and passed to
      ``evaluate``.
    - Chinese input is split into characters and passed to ``evaluate_cn``.
    - If the detected language already equals ``lang_out``, or is neither
      ``'en'`` nor ``'zh'``, the text is returned unchanged.

    Args:
        text: The raw input string.
        lang_out: Target language code, compared with the code returned by
            ``detect`` (``'en'`` and ``'zh'`` are the handled values).

    Returns:
        ``(result, sentence)``: the translated (or unchanged) text, and the
        sentence extracted by ``detect``. ``sentence`` is ``''`` when it is
        identical to ``text``. If no checkpoint exists, an error is printed
        and ``(text, '')`` is returned.

    Raises:
        KeyError: If a token is missing from the relevant word dictionary.
    """
    # Data preprocessing: the vocabularies determine the model's vocab sizes.
    data = PrepareData()
    parsee.src_vocab = len(data.en_word_dict)
    parsee.tgt_vocab = len(data.cn_word_dict)

    # Without a trained checkpoint there is nothing to translate with.
    # Previously this path fell through to ``return result, sentence`` with
    # both names unbound (UnboundLocalError). Now it returns the input
    # untouched, like the other "cannot translate" branches below.
    if not os.path.exists(parsee.save_file):
        print("Error: pleas train before evaluate")
        return text, ''

    # Initialise the model and load the trained weights.
    model = make_model(
        parsee.src_vocab,
        parsee.tgt_vocab,
        parsee.layers,
        parsee.d_model,
        parsee.d_ff,
        parsee.h_num,
        parsee.dropout
    )
    # map_location lets a checkpoint saved on a GPU load on the current
    # device (e.g. a CPU-only machine).
    model.load_state_dict(torch.load(parsee.save_file, map_location=parsee.device))
    # Inference: disable the dropout layers (built with parsee.dropout).
    model.eval()

    sentence, _detail, lang = detect(text)
    if lang == lang_out:
        print("same")
        result = text
    elif lang == 'en':
        # Word-level ids from lower-cased, whitespace-separated tokens.
        token_ids = [data.en_word_dict[word] for word in sentence.lower().split()]
        result = "".join(evaluate(data, model, token_ids))
    elif lang == 'zh':
        # Character-level ids.
        # NOTE(review): the model's source side is sized by en_word_dict
        # (src_vocab above), yet Chinese ids are fed to it here -- confirm
        # evaluate_cn accounts for this.
        token_ids = [data.cn_word_dict[char] for char in sentence]
        result = "".join(evaluate_cn(sentence, model, token_ids))
    else:
        result = text

    # Return the detected sentence only when it differs from the input.
    if sentence == text:
        sentence = ''
    return result, sentence
if __name__ == "__main__":
    # Demo run: translate a fixed Chinese sample with target language 'zh'.
    sample = "我是个男孩。"
    output, detected = translate(sample, 'zh')
    print(output, detected)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化