加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
my_model.py 435 Bytes
一键复制 编辑 原始数据 按行查看 历史
KunCheng-He 提交于 2023-06-14 15:40 . py file demo
"""
本文件包含该任务的模型
"""
from torch import nn
import torch
from config import *
class MyRnn(nn.Module):
""" 基于RNN来预测 """
def __init__(self, hidden_size: int) -> None:
super().__init__()
pass
def forward(self, x):
pass
return x
if __name__ == '__main__':
# 测试
model = MyRnn(100)
x = torch.rand((5, 3, 52))
y = model(x)
print(y.shape)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化