sto.py
from sto import StoDataset, NeuralNetwork
from torch.utils.data import random_split, DataLoader
from torch import nn, optim, no_grad


def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over the dataloader."""
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Log progress every 6 batches and on the last batch of the epoch.
        current = batch * dataloader.batch_size + len(X)
        if batch % 6 == 0 or current == size:
            print(f"loss: {loss.item():>7f} [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
    """Evaluate the model and report the average loss over the test set."""
    model.eval()
    num_batches = len(dataloader)
    test_loss = 0
    with no_grad():  # no gradients needed during evaluation
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
    test_loss /= num_batches
    print(f"Test Avg loss: {test_loss:>8f} \n")


def run_train():
    # 90/10 train/test split of the full dataset.
    dataset_object = StoDataset()
    train_data_size = int(len(dataset_object) * 0.90)
    test_data_size = len(dataset_object) - train_data_size
    train_data, test_data = random_split(dataset_object, [train_data_size, test_data_size])

    batch_size = 100
    train_data_loader = DataLoader(train_data, batch_size=batch_size)
    test_data_loader = DataLoader(test_data, batch_size=batch_size)

    model = NeuralNetwork()
    loss_fn = nn.MSELoss()  # mean-squared-error loss, i.e. a regression objective
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_data_loader, model, loss_fn, optimizer)
        test(test_data_loader, model, loss_fn)
    print("Done!")


if __name__ == '__main__':
    run_train()
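

# --- Hypothetical sketch (not part of this file) ----------------------------
# The `sto` module that provides StoDataset and NeuralNetwork is not shown on
# this page. The sketch below is a minimal, assumed implementation that would
# satisfy the interface the training script uses: StoDataset as a map-style
# torch Dataset yielding (features, target) pairs, and NeuralNetwork as a
# small fully connected regressor compatible with nn.MSELoss. The shapes
# (8 input features, 1 output) are placeholders, not taken from the original.

import torch
from torch.utils.data import Dataset


class StoDataset(Dataset):
    """Assumed map-style dataset returning (features, target) float tensors."""

    def __init__(self, n_samples=1000, n_features=8):
        # Random data stands in for whatever the real module loads.
        self.X = torch.randn(n_samples, n_features)
        self.y = torch.randn(n_samples, 1)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


class NeuralNetwork(nn.Module):
    """Assumed small MLP regressor matching the MSELoss training loop above."""

    def __init__(self, n_features=8):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_features, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        return self.layers(x)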