加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
main.py 1.98 KB
一键复制 编辑 原始数据 按行查看 历史
zhaoguangyao 提交于 2018-03-12 19:38 . all finish
# -*- coding: utf-8 -*-
import random
import pickle
import argparse
from driver.Model import *
from driver.Train import *
from driver.Config import *
from data.Vocab import *
from data.DataLoader import *
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if __name__ == '__main__':
# random
torch.manual_seed(666)
torch.cuda.manual_seed(666)
random.seed(666)
np.random.seed(666)
# gpu
gpu = torch.cuda.is_available()
print("GPU available: ", gpu)
print("CuDNN: ", torch.backends.cudnn.enabled)
# parameters
argparser = argparse.ArgumentParser()
argparser.add_argument('--config_file', default='default.cfg')
argparser.add_argument('--thread', default=1, type=int, help='thread num')
argparser.add_argument('--use_cuda', action='store_true', default=False)
args, extra_args = argparser.parse_known_args()
config = Configurable(args.config_file, extra_args)
torch.set_num_threads(args.thread)
config.use_cuda = False
if gpu and args.use_cuda:
config.use_cuda = True
print("\nGPU using status: ", config.use_cuda)
# data vocab embedding
train_data = read_corpus(config.train_file)
dev_data = read_corpus(config.dev_file)
test_data = read_corpus(config.test_file)
vocab_src, vocab_tgt = create_vocabularies(train_data, config.vocab_size)
pickle.dump(vocab_src, open(config.save_src_vocab_path, 'wb'))
pickle.dump(vocab_tgt, open(config.save_tgt_vocab_path, 'wb'))
m_embedding = vocab_src.create_vocab_embs(config.embedding_file)
print("Sentence Number: #train = %d" % (len(train_data)))
# model train
if config.which_model == 'lstm':
m_model = LSTM(config, vocab_src.size, vocab_tgt.size, PAD, m_embedding)
else:
raise RuntimeError("Invalid optim method: " + config.which_model)
if config.use_cuda:
torch.backends.cudnn.enabled = True
m_model = m_model.cuda()
train(m_model, train_data, dev_data, vocab_src, vocab_tgt, config)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化