代码拉取完成,页面将自动刷新
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@file train.py
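@brief Train an MNIST classifier selected by name.
@details Trains with SGD and cross-entropy loss, optionally reports test-set accuracy after every epoch, pickles per-epoch losses/accuracies under out/, and saves the final weights under model/.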
@author Shivelino
@date 2023-12-23 19:10
@version 0.0.1
@par Copyright(c):
@par todo:
@par history:
"""
import os
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import os.path as op
import pickle as pkl
from nets import get_model
from utils import get_device, get_dataloader_mnist
def _run_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Returns 0.0 when ``loader`` yields no batches instead of dividing by zero.
    """
    model.train()  # switch modules such as dropout/batch-norm to training mode
    running_loss = 0.0
    n_batches = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1
    return running_loss / n_batches if n_batches else 0.0


def _evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` on ``loader`` as a fraction in [0, 1].

    Returns 0.0 when ``loader`` yields no samples instead of dividing by zero.
    """
    model.eval()  # switch modules such as dropout/batch-norm to inference mode
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total if total else 0.0


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing the file handle deterministically."""
    with open(path, "wb") as f:
        pkl.dump(obj, f)


def train(opt):
    """Train the model named ``opt.model`` on MNIST and save the results.

    Args:
        opt: Parsed options. Reads ``data_dir``, ``batch_size``, ``model``,
            ``n_epochs``, ``test_every_epoch`` and ``save_loss``. Optionally
            reads ``lr`` (default 0.001) and ``momentum`` (default 0.9).

    Side effects:
        Prints per-epoch training loss (and test accuracy when
        ``opt.test_every_epoch`` is set). If ``opt.save_loss`` is set, writes
        ``out/losses_<model>.pkl``, plus ``out/acces_<model>.pkl`` when
        testing every epoch. Always writes the final weights to
        ``model/model_<model>.pth``.
    """
    trainloader, testloader = get_dataloader_mnist(opt.data_dir, opt.batch_size)

    device = get_device()
    print(f"Current Model: {opt.model}")
    model = get_model(opt.model).to(device)
    criterion = nn.CrossEntropyLoss()
    # lr/momentum fall back to the previously hard-coded values so callers whose
    # ``opt`` lacks these attributes keep the old behaviour.
    optimizer = optim.SGD(
        model.parameters(),
        lr=getattr(opt, "lr", 0.001),
        momentum=getattr(opt, "momentum", 0.9),
    )

    losses, acces = [], []
    for epoch in range(opt.n_epochs):
        epoch_loss = _run_epoch(model, trainloader, criterion, optimizer, device)
        losses.append(epoch_loss)
        print(f'Epoch {epoch + 1}, Training Loss: {epoch_loss:.4f}')

        if opt.test_every_epoch:
            accuracy = _evaluate(model, testloader, device)
            acces.append(accuracy)
            print(f'Accuracy on the test set: {accuracy * 100:.2f}%')

    if opt.save_loss:
        os.makedirs("out", exist_ok=True)
        _dump_pickle(losses, f"out/losses_{opt.model}.pkl")
        if opt.test_every_epoch:
            _dump_pickle(acces, f"out/acces_{opt.model}.pkl")

    os.makedirs("model", exist_ok=True)
    torch.save(model.state_dict(), f'model/model_{opt.model}.pth')
def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False" and "0") as True, so boolean flags go through this converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Build the command-line parser for this training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default="lenet", help='model')
    parser.add_argument('--data_dir', type=str, default="data", help='data directory')
    parser.add_argument('--batch_size', type=int, default=128, help='size of the batches')
    parser.add_argument('--n_epochs', type=int, default=25, help='number of epochs of training')
    parser.add_argument('--lr', type=float, default=0.001, help='SGD learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
    parser.add_argument('--test_every_epoch', type=str2bool, default=True, help='run test in every epoch')
    parser.add_argument('--save_loss', type=str2bool, default=True, help='save loss in every epoch')
    return parser


if __name__ == '__main__':
    train(build_parser().parse_args())
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。