加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
Model.py 748 Bytes
一键复制 编辑 原始数据 按行查看 历史
Mist.Wang 提交于 2020-05-23 20:35 . Add files via upload
import torch
import torch.nn as nn
class nn_LSTM(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.hidden_size = hidden_size
self.lstm = nn.LSTM(input_size, hidden_size)
self.out = nn.Linear(hidden_size, output_size)
def forward(self, X, hidden):
_, hidden = self.lstm(X, hidden)
output = self.out(hidden[0])
return output, hidden
def initHidden(self):
return (torch.zeros(1, 1, self.hidden_size).cuda(),
torch.zeros(1, 1, self.hidden_size).cuda())
def initHidden_test(self):
return (torch.zeros(1, 1, self.hidden_size),
torch.zeros(1, 1, self.hidden_size))
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化