代码拉取完成,页面将自动刷新
"""
训练mnist分类器
"""
import os.path
import torch
import torchvision.transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from dataclasses import dataclass
from MnistClassifier import MnistClassifier
from tqdm import tqdm
import wandb
wandb.require("core")
from torch.nn.functional import softmax
import numpy as np
from Config import get_diffusion, ConfigMnist
def getRandomT(size):
# 将列表转换为NumPy数组
numbers_array = np.array([i for i in range(980, 0, -20)])
# 从NumPy数组中随机选择5个不重复的元素
selected_numbers_array = np.random.choice(numbers_array, size=size)
return torch.tensor(selected_numbers_array)
def get_transform():
class RescaleChannels(object):
def __call__(self, sample):
return 2 * sample - 1
return torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
RescaleChannels(),
])
@dataclass
class Config():
def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.batch_size_train = 128
self.batch_size_test = 1000
self.learning_rate = 0.01
self.mnist_root = "../MNIST"
self.save_root = "./log_classifier"
self.log_iter = int(60000 / self.batch_size_train) # 一轮一个log
self.save_checkpoint = int(60000 / self.batch_size_train) * 4 # 四轮一个保存
# 总训练轮数
self.epoch = 50
self.iterations = int(60000 / self.batch_size_train * self.epoch)
# wandb相关
self.wandb_project = "mnist_train"
self.wandb_name = "with noise"
# 模型定义相关
self.time_dim =128
self.drop_p = 0.5
self.type_act = "silu"
self.time_scale = 1.0
def cycle(dl):
while True:
for data in dl:
yield data
if __name__ == "__main__":
# 定义ddpm模型,只是为了方便添加噪声。
diffusion_model = get_diffusion(ConfigMnist()).requires_grad_(False)
config = Config()
wandb.init(project=config.wandb_project, name=config.wandb_name)
# 首先定义模型
model = MnistClassifier(config.time_dim, config.drop_p, config.type_act, config.time_scale)
model = model.to(config.device)
print("模型定义完毕")
if not os.path.exists(config.save_root):
os.mkdir(config.save_root)
# 拿到数据集,统一缩放到[-1, 1]区间
train_dataset = datasets.MNIST(
root=config.mnist_root,
train=True,
download=True,
transform=get_transform()
) # 60000条数据
test_dataset = datasets.MNIST(
root=config.mnist_root,
train=False,
download=True,
transform=get_transform()
) # 10000条数据
train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size_train, shuffle=True)
train_dataloader = cycle(train_dataloader)
test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size_test, shuffle=True)
len_test = len(test_dataset)
print("数据集定义完毕")
# 定义优化器
opt = torch.optim.Adam(params=model.parameters(), lr=config.learning_rate)
# 定义损失函数
loss_fn = torch.nn.CrossEntropyLoss()
print("准备工作完成,开始训练")
# 开始训练
model = model.train()
loss_train = 0
for i in tqdm(range(config.iterations)):
# 训练一组样本
x, y = next(train_dataloader)
x = x.to(config.device)
y = y.to(config.device)
# 拿到x之后,需要对x添加噪声
# 拿到时间t
t = getRandomT(x.shape[0]).to(config.device)
# 首先拿到和x一样大小的噪声
noise = torch.randn(size=x.shape, device=config.device)
# 拿到加了噪声的x
x = diffusion_model.noiseX(x, noise, t)
pred = model(x, t)
loss = loss_fn(pred, y)
loss_train += loss.item()
opt.zero_grad()
loss.backward()
opt.step()
if i % config.log_iter == 0:
loss_train /= config.log_iter # 计算出每一个batch的损失
print("{} / {} -->{}".format(i, config.iterations, loss_train))
wandb.log({"train_loss":loss_train})
loss_train = 0
if i % config.save_checkpoint == 0:
# 需要保存模型了
model_filename = os.path.join(config.save_root, "model_{}.pth".format(i // config.save_checkpoint))
opt_filename = os.path.join(config.save_root, "opt_{}.pth".format(i // config.save_checkpoint))
torch.save(model.state_dict(), model_filename)
torch.save(opt.state_dict(), opt_filename)
# 在测试集上进行评估
loss_test = 0
acc_test = 0
model = model.eval()
with torch.no_grad():
for x, y in test_dataloader:
x = x.to(config.device)
y = y.to(config.device)
# 拿到x之后,需要对x添加噪声
# 拿到时间t
# t = torch.randint(low=0, high=diffusion_model.T, size=(x.shape[0],), device=config.device) # 左闭右开区间
t = getRandomT(x.shape[0]).to(config.device)
# 首先拿到和x一样大小的噪声
noise = torch.randn(size=x.shape, device=config.device)
# 拿到加了噪声的x
x = diffusion_model.noiseX(x, noise, t)
pred = model(x, t)
loss = loss_fn(pred, y)
loss_test += loss.item()
# 拿到准确率
pred_softmax = softmax(pred, dim=-1)
pred_y = pred_softmax.max(dim=-1)[1]
acc_test += (pred_y == y).sum().item()
loss_test /= len(test_dataloader)
acc_test /= len(test_dataset)
print("-"*20)
print("{} / {} loss-->{}, acc-->{}".format(i, config.iterations, loss_test, acc_test))
print("-"*20)
wandb.log({"test_loss":loss_test, "test_acc":acc_test})
model = model.train()
wandb.finish()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。