代码拉取完成,页面将自动刷新
from sto import StoDataset, NeuralNetwork
from torch.utils.data import random_split, DataLoader
from torch import nn, optim, no_grad
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader`` and print periodic progress.

    For every batch the model's prediction is scored with ``loss_fn``, the
    loss is back-propagated, and ``optimizer`` takes one step before its
    gradients are cleared.  A progress line is printed on every sixth batch
    (starting with the first one) and whenever the number of samples
    processed so far equals ``len(dataloader.dataset)``.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches; must expose a sized
            ``dataset`` attribute.
        model: Callable model being trained; ``model.train()`` is invoked
            before the first batch.
        loss_fn: Callable ``loss_fn(pred, y)`` whose result supports
            ``backward()`` and ``item()``.
        optimizer: Object providing ``step()`` and ``zero_grad()``.
    """
    size = len(dataloader.dataset)
    model.train()

    # Count processed samples directly instead of deriving the position from
    # ``dataloader.batch_size``: that attribute is ``None`` when the loader
    # was built with a custom ``batch_sampler``, which made the old
    # ``batch * dataloader.batch_size`` arithmetic raise ``TypeError``.
    seen = 0
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        seen += len(X)
        if batch % 6 == 0 or seen == size:
            # ``item()`` is only called on batches that are actually logged.
            print(f"loss: {loss.item():>7f} [{seen:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print the mean per-batch loss.

    The model is switched to evaluation mode and every batch is scored with
    ``loss_fn`` inside a ``no_grad()`` context.  The printed figure is the
    arithmetic mean of the per-batch loss values (each batch counts equally,
    regardless of how many samples it holds).  If the loader yields no
    batches, ``nan`` is reported instead of raising ``ZeroDivisionError``.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches.
        model: Callable model to evaluate; ``model.eval()`` is invoked first.
        loss_fn: Callable ``loss_fn(pred, y)`` whose result supports ``item()``.
    """
    model.eval()

    total_loss = 0.0
    # Batches are counted while iterating so the loader does not need to
    # support ``len()`` and an empty loader can be detected afterwards.
    num_batches = 0
    with no_grad():
        for X, y in dataloader:
            total_loss += loss_fn(model(X), y).item()
            num_batches += 1

    test_loss = total_loss / num_batches if num_batches else float("nan")
    print(f"Test Avg loss: {test_loss:>8f} \n")


# The function name matches pytest's default ``test*`` collection pattern;
# mark it explicitly so it is never collected as a test case when this module
# is imported (or star-imported) from a test file.
test.__test__ = False
def run_train(*, epochs=5, batch_size=100, learning_rate=0.001, train_fraction=0.90):
    """Train a ``NeuralNetwork`` on ``StoDataset`` and evaluate it after each epoch.

    The dataset is randomly split into a training and a held-out part, both
    wrapped in ``DataLoader`` objects.  The model is optimised with Adam
    against an MSE loss; after every epoch the held-out loss is printed.

    Args:
        epochs: Number of passes over the training split.
        batch_size: Batch size used for both loaders.
        learning_rate: Learning rate passed to ``optim.Adam``.
        train_fraction: Share of the dataset (between 0 and 1) assigned to
            the training split; the remainder is used for evaluation.

    Returns:
        The trained model instance, so callers can save or reuse it.

    Raises:
        ValueError: If ``train_fraction`` lies outside ``[0, 1]``.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be within [0, 1], got {train_fraction!r}")

    dataset_object = StoDataset()
    train_data_size = int(len(dataset_object) * train_fraction)
    # Whatever is left after truncation goes to the evaluation split, so the
    # two sizes always add up to the full dataset length.
    test_data_size = len(dataset_object) - train_data_size
    train_data, test_data = random_split(dataset_object, [train_data_size, test_data_size])

    # NOTE(review): the loaders use the default (non-shuffled) order, so the
    # training batches are identical in every epoch; only the initial
    # ``random_split`` is randomised.  Confirm whether that is intended.
    train_data_loader = DataLoader(train_data, batch_size=batch_size)
    test_data_loader = DataLoader(test_data, batch_size=batch_size)

    model = NeuralNetwork()
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_data_loader, model, loss_fn, optimizer)
        test(test_data_loader, model, loss_fn)
    print("Done!")
    return model
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    run_train()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。