加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
main.py 5.21 KB
一键复制 编辑 原始数据 按行查看 历史
vipzgy 提交于 2017-11-19 21:53 . cnn residual
import random
import datetime
import argparse
import model
import train
import mydatasets
import torch
# Fix all RNG seeds up front so experiment runs are reproducible
# across both Python's `random` module and PyTorch.
SEED = 66
random.seed(SEED)
torch.manual_seed(SEED)
# All hyper-parameters, overridable from the command line.
# FIX: the description was previously passed as the first positional argument,
# which argparse treats as `prog` (the program name in usage output), not as a
# description; pass it as `description=` instead.
parser = argparse.ArgumentParser(description="residual network for sequence classification")
# model loading / external parameter file
parser.add_argument("--load-path", type=str, default=None, help="load model from your path")
parser.add_argument("--parameters-path", type=str, default=None, help="change parameters from a text file")
# evaluation modes
parser.add_argument('--F1', action='store_true', default=False, help="get F1 number")
parser.add_argument('--test', action='store_true', default=False, help="whether test")
# embedding cache (pickled vocab/embedding to avoid re-reading the GloVe file)
parser.add_argument("--pkl-path", type=str, default="./data/fine_grained_task", help="save embedding path which can be used, just for saving time")
parser.add_argument("--pkl-name", type=str, default="raw.clean.pkl", help="save embedding name which can be used, just for saving time")
parser.add_argument("--embedding-file", type=str, default="D:/AI/embedding&corpus/glove.840B.300d.txt", help="embedding path")
parser.add_argument("--embedding-name", type=str, default="glove.840B.300d.txt", help="update embedding save name")
# NOTE(review): store_true with default=True makes this flag a no-op — it can
# never be switched off from the command line; kept as-is for compatibility.
parser.add_argument("--use-embedding", action="store_true", default=True, help="whether use embedding")
# data files
parser.add_argument("--train-file", type=str, default="./data/fine_grained_task/raw.clean.train", help="train file path")
parser.add_argument("--dev-file", type=str, default="./data/fine_grained_task/raw.clean.dev", help="development file path")
parser.add_argument("--test-file", type=str, default="./data/fine_grained_task/raw.clean.test", help="test file path")
# checkpointing
parser.add_argument('--save-dir', type=str, default='snapshot', help="all train model save path")
parser.add_argument('--train-name', type=str, default='cnnresidual135', help="train model save name")
parser.add_argument("--cuda", action="store_true", default=False, help="whether use cuda")
# logging / evaluation cadence (in steps)
parser.add_argument('--log-interval', type=int, default=1)
parser.add_argument('--test-interval', type=int, default=100)
parser.add_argument('--save-interval', type=int, default=100)
# training schedule
parser.add_argument('--epochs', type=int, default=8)
parser.add_argument("--batch-size", type=int, default=16, help="batch size")
# model dimensions (label-num and embed-num are overwritten after the data is
# loaded; kernel-size is a comma-separated string parsed into a list later)
parser.add_argument("--label-num", type=int, default=2, help="the number of labels")
parser.add_argument("--embedding-dim", type=int, default=300, help="use embedding dimension")
parser.add_argument("--embed-num", type=int, default=None, help="the number of embedding words")
parser.add_argument('--dropout-embed', type=float, default=0.5)
parser.add_argument('--dropout-rnn', type=float, default=0.5)
parser.add_argument('--hidden-size', type=int, default=100)
parser.add_argument('--kernel-num', type=int, default=100)
parser.add_argument('--kernel-size', type=str, default='1,3,5')
parser.add_argument('--residual-num', type=int, default=2, help="the number of layers each residual block contains")
parser.add_argument('--residual-layers', type=int, default=1, help="the number of residual layers")
# optimization
parser.add_argument("--optimizer", type=str, default="Adam", help="using which optimizer, such as SGD")
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--weight-decay', type=float, default=1e-8)
parser.add_argument('--lr-scheduler', type=str, default=None)
parser.add_argument('--clip-norm', type=str, default=None)
# Parse the command line. NOTE: this runs at module import time and reads sys.argv.
args = parser.parse_args()
# If --parameters-path is given, hyper-parameters were meant to be overridden
# from that text file — not implemented yet (the branch is a stub).
if args.parameters_path is not None:
    pass
# date processing
train_data = mydatasets.MyDatasets(args.train_file, args.label_num, args.pkl_path, args.pkl_name, args.embedding_file,
args.embedding_name, args.embedding_dim, args.batch_size)
dev_data = mydatasets.MyDatasets(args.dev_file, args.label_num, args.pkl_path, args.pkl_name, args.embedding_file,
args.embedding_name, args.embedding_dim, args.batch_size, train_data)
test_data = mydatasets.MyDatasets(args.test_file, args.label_num, args.pkl_path, args.pkl_name, args.embedding_file,
args.embedding_name, args.embedding_dim, args.batch_size, train_data)
# Refresh the derived hyper-parameters now that the vocabularies exist.
args.label_num = len(train_data.vocabulary_label.word2id)
args.embed_num = len(train_data.vocabulary_text.word2id)
# Only enable CUDA when it was both requested and is actually available.
args.cuda = args.cuda and torch.cuda.is_available()
# Turn the comma-separated string ('1,3,5') into a list of ints.
args.kernel_size = list(map(int, args.kernel_size.split(',')))
# Echo the final configuration.
print("\nParameters:")
for name, value in vars(args).items():
    print("\t{}={}".format(name, value))
# Model construction: build a fresh ResidualCNN, or restore a pickled model
# from --load-path when one is given; then move it to the GPU if enabled.
m_model = None
if args.load_path is None:
    m_model = model.ResidualCNN(args, train_data.embedding)
else:
    # BUG FIX: this previously printed `args.snapshot`, an attribute that is
    # never defined, so supplying --load-path always crashed with
    # AttributeError before the model was even loaded.
    print('\nLoading model from [%s]...' % args.load_path)
    try:
        m_model = torch.load(args.load_path)
    except Exception as e:
        # Top-level boundary: report why loading failed instead of silently
        # claiming the snapshot doesn't exist, then abort.
        print("Sorry, this snapshot doesn't exist or could not be loaded: %s" % e)
        exit()
if args.cuda:
    m_model.cuda()
# Entry-point dispatch: F1 evaluation / stand-alone testing / full training.
# Restrict PyTorch intra-op parallelism to a single thread.
torch.set_num_threads(1)
TIME_FMT = '%Y-%m-%d_%H-%M-%S'
if args.F1:
    # F1 evaluation is not wired up yet.
    pass
    # train.getF1(args, m_model, test_data.iterator, train_data.vocabulary_label)
elif args.test:
    # Stand-alone testing is not wired up yet.
    pass
    # train.test(args, m_model, test_data.iterator)
else:
    # Print start/end timestamps around the training run.
    print(datetime.datetime.now().strftime(TIME_FMT))
    train.train(args, m_model, train_data.iterator, test_data.iterator)
    print(datetime.datetime.now().strftime(TIME_FMT))
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化