代码拉取完成,页面将自动刷新
# -*- coding: utf-8 -*-
import random
import pickle
import argparse
from driver.Model import *
from driver.Train import *
from driver.Config import *
from data.Vocab import *
from data.DataLoader import *
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if __name__ == '__main__':
# random
torch.manual_seed(666)
torch.cuda.manual_seed(666)
random.seed(666)
np.random.seed(666)
# gpu
gpu = torch.cuda.is_available()
print("GPU available: ", gpu)
print("CuDNN: ", torch.backends.cudnn.enabled)
# parameters
argparser = argparse.ArgumentParser()
argparser.add_argument('--config_file', default='default.cfg')
argparser.add_argument('--thread', default=1, type=int, help='thread num')
argparser.add_argument('--use_cuda', action='store_true', default=False)
args, extra_args = argparser.parse_known_args()
config = Configurable(args.config_file, extra_args)
torch.set_num_threads(args.thread)
config.use_cuda = False
if gpu and args.use_cuda:
config.use_cuda = True
print("\nGPU using status: ", config.use_cuda)
# data vocab embedding
train_data = read_corpus(config.train_file)
dev_data = read_corpus(config.dev_file)
test_data = read_corpus(config.test_file)
vocab_src, vocab_tgt = create_vocabularies(train_data, config.vocab_size)
pickle.dump(vocab_src, open(config.save_src_vocab_path, 'wb'))
pickle.dump(vocab_tgt, open(config.save_tgt_vocab_path, 'wb'))
m_embedding = vocab_src.create_vocab_embs(config.embedding_file)
print("Sentence Number: #train = %d" % (len(train_data)))
# model train
if config.which_model == 'lstm':
m_model = LSTM(config, vocab_src.size, vocab_tgt.size, PAD, m_embedding)
else:
raise RuntimeError("Invalid optim method: " + config.which_model)
if config.use_cuda:
torch.backends.cudnn.enabled = True
m_model = m_model.cuda()
train(m_model, train_data, dev_data, vocab_src, vocab_tgt, config)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。