加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test.py 852 Bytes
一键复制 编辑 原始数据 按行查看 历史
Mist.Wang 提交于 2020-05-23 20:35 . Add files via upload
from util import *
from scipy.sparse import csr_matrix
import pickle
from Model import nn_LSTM
from torch.optim.lr_scheduler import LambdaLR
import os
hidden_size = 256
seq_length = 25
root = r"模型保存/混合/"
with open(root + "X_train.pickle", 'rb') as handle:
X_train = pickle.load(handle)
with open(root + "y_train.pickle", 'rb') as handle:
y_train = pickle.load(handle)
with open(root + "chars.pickle", 'rb') as handle:
chars = pickle.load(handle)
with open(root + "vocab_size.pickle", 'rb') as handle:
vocab_size = pickle.load(handle)
rnn = nn_LSTM(vocab_size, hidden_size, vocab_size)
for batch in get_batch(X_train, y_train, seq_length):
X_batch, y_batch = batch
rnn.load_state_dict(torch.load(root +'save_model.pth'))
print(sample_chars(rnn, X_batch[0], rnn.initHidden_test(), chars, 200))
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化