Create your Gitee Account
Explore and code with more than 12 million developers. Free private repositories! :)
Sign up
文件
This repository doesn't specify a license. Please pay attention to the specific project description and its upstream code dependencies when using it.
Clone or Download
train.py 2.31 KB
Copy Edit Raw Blame History
15748 authored 2024-09-26 09:22 . 完整模型训练套路
import torch.optim
import torchvision
from torch import nn
# 引入自己的模型
import model
from model import Tudui
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# CIFAR-10 train/test splits. download=True fetches them into ./data if they are
# not already present; images are only converted with ToTensor (no normalization
# or augmentation is applied here).
to_tensor = torchvision.transforms.ToTensor()  # stateless, safe to share
train_data = torchvision.datasets.CIFAR10("./data", train=True, transform=to_tensor,
                                          download=True)
test_data = torchvision.datasets.CIFAR10("./data", train=False, transform=to_tensor,
                                         download=True)

train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集长度:{}".format(train_data_size))
print("测试数据集长度:{}".format(test_data_size))

batch_size = 64
# Reshuffle the training set every epoch so SGD does not see the same batch
# order each pass; the test set order does not matter for evaluation.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
# Network to train (class imported from the project's `model` module).
tudui = Tudui()

# Loss function: cross-entropy for multi-class classification.
loss_fn = nn.CrossEntropyLoss()

# Optimizer: plain SGD over all model parameters with a fixed learning rate.
lr = 1e-2
optimizer = torch.optim.SGD(params=tudui.parameters(), lr=lr)

# Running counters, used by the loop as TensorBoard global steps:
# optimizer steps taken so far / evaluation passes completed so far.
total_train_step = total_test_step = 0

# Number of epochs (full passes over the training set).
epoch = 10

# TensorBoard event files are written under ./logs.
writer = SummaryWriter(log_dir='./logs')
# Main loop: each epoch runs one training pass, then one evaluation pass on the
# test split, logs to TensorBoard, and saves a checkpoint.
try:
    for i in range(epoch):
        print("--------第{}轮--------".format(i + 1))

        # Training phase. train() switches mode-dependent layers (e.g. dropout /
        # batch-norm, if Tudui contains any) into training behaviour.
        tudui.train()
        for imgs, targets in train_dataloader:
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)

            # Clear stale gradients, backpropagate, then apply the SGD update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_step += 1
            if total_train_step % 100 == 0:
                train_loss = loss.item()
                print("loss:{}".format(train_loss))
                writer.add_scalar("train_loss", train_loss, total_train_step)

        # Evaluation phase on the test split.
        total_test_loss = 0.0  # sum of per-batch mean losses
        total_acc = 0  # number of correctly classified test samples
        tudui.eval()
        with torch.no_grad():  # no autograd graph is needed for evaluation
            for imgs, targets in test_dataloader:
                outputs = tudui(imgs)
                # Predicted class = index of the highest output along dim 1.
                total_acc += (outputs.argmax(1) == targets).sum().item()
                # .item() accumulates a plain Python float rather than a tensor.
                total_test_loss += loss_fn(outputs, targets).item()
        test_accuracy = total_acc / test_data_size
        print(total_test_loss)
        print(test_accuracy)
        writer.add_scalar("test_loss", total_test_loss, total_test_step)
        writer.add_scalar("test_acc", test_accuracy, total_test_step)
        total_test_step += 1

        # NOTE(review): this pickles the whole module object, so loading it
        # requires the Tudui class to be importable. Saving tudui.state_dict()
        # is the more portable form, but it would change the checkpoint format
        # that existing loaders of these files may expect.
        torch.save(tudui, "tudui_{}.pth".format(i))
finally:
    # Flush pending events and close the event file, even if training is
    # interrupted part-way through.
    writer.close()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化