代码拉取完成,页面将自动刷新
from util import *
from scipy.sparse import csr_matrix
import pickle
from Model import nn_LSTM
from torch.optim.lr_scheduler import LambdaLR
import os
hidden_size = 256
seq_length = 25
root = r"模型保存/混合/"
with open(root + "X_train.pickle", 'rb') as handle:
X_train = pickle.load(handle)
with open(root + "y_train.pickle", 'rb') as handle:
y_train = pickle.load(handle)
with open(root + "chars.pickle", 'rb') as handle:
chars = pickle.load(handle)
with open(root + "vocab_size.pickle", 'rb') as handle:
vocab_size = pickle.load(handle)
rnn = nn_LSTM(vocab_size, hidden_size, vocab_size)
for batch in get_batch(X_train, y_train, seq_length):
X_batch, y_batch = batch
rnn.load_state_dict(torch.load(root +'save_model.pth'))
print(sample_chars(rnn, X_batch[0], rnn.initHidden_test(), chars, 200))
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。