train_classifier.py
"""
训练mnist分类器
"""
import os.path
import torch
import torchvision.transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from dataclasses import dataclass
from MnistClassifier import MnistClassifier
from tqdm import tqdm
import wandb
wandb.require("core")
from torch.nn.functional import softmax
import numpy as np
from Config import get_diffusion, ConfigMnist
def getRandomT(size):
    # Candidate timesteps: 980, 960, ..., 20 (descending, step 20), as a NumPy array
    numbers_array = np.array([i for i in range(980, 0, -20)])
    # Randomly sample `size` timesteps from the array (with replacement)
    selected_numbers_array = np.random.choice(numbers_array, size=size)
    return torch.tensor(selected_numbers_array)
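# Example usage (values vary per call; sampling is random and with replacement):
#   getRandomT(4)  # -> e.g. tensor([980, 340,  20, 760])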
def get_transform():
    class RescaleChannels(object):
        def __call__(self, sample):
            return 2 * sample - 1
    return torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        RescaleChannels(),
    ])
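# ToTensor() maps pixel values to [0, 1]; RescaleChannels then maps them to
# [-1, 1] via 2 * x - 1, presumably to match the value range used by the
# diffusion model.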
@dataclass
class Config():
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.batch_size_train = 128
        self.batch_size_test = 1000
        self.learning_rate = 0.01
        self.mnist_root = "../MNIST"
        self.save_root = "./log_classifier"
        self.log_iter = int(60000 / self.batch_size_train)  # log once per epoch
        self.save_checkpoint = int(60000 / self.batch_size_train) * 4  # save once every four epochs
        # total number of training epochs
        self.epoch = 50
        self.iterations = int(60000 / self.batch_size_train * self.epoch)
        # wandb settings
        self.wandb_project = "mnist_train"
        self.wandb_name = "with noise"
        # model definition settings
        self.time_dim = 128
        self.drop_p = 0.5
        self.type_act = "silu"
        self.time_scale = 1.0
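# With the defaults above: log_iter = int(60000 / 128) = 468 iterations per epoch,
# save_checkpoint = 468 * 4 = 1872 iterations (every four epochs), and
# iterations = int(468.75 * 50) = 23437 training iterations in total.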
def cycle(dl):
    while True:
        for data in dl:
            yield data
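# cycle() turns a DataLoader into an infinite iterator so the training loop can
# simply call next(train_dataloader) for a fixed number of iterations instead of
# looping over epochs explicitly.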
if __name__ == "__main__":
    # Define the DDPM model; it is only used here to add noise to the inputs.
    diffusion_model = get_diffusion(ConfigMnist()).requires_grad_(False)
    config = Config()
    wandb.init(project=config.wandb_project, name=config.wandb_name)
    # First, define the classifier model
    model = MnistClassifier(config.time_dim, config.drop_p, config.type_act, config.time_scale)
    model = model.to(config.device)
    print("Model defined")
    if not os.path.exists(config.save_root):
        os.mkdir(config.save_root)
    # Load the datasets; inputs are rescaled to the [-1, 1] range
    train_dataset = datasets.MNIST(
        root=config.mnist_root,
        train=True,
        download=True,
        transform=get_transform()
    )  # 60000 samples
    test_dataset = datasets.MNIST(
        root=config.mnist_root,
        train=False,
        download=True,
        transform=get_transform()
    )  # 10000 samples
    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size_train, shuffle=True)
    train_dataloader = cycle(train_dataloader)
    test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size_test, shuffle=True)
    len_test = len(test_dataset)
    print("Datasets ready")
    # Define the optimizer
    opt = torch.optim.Adam(params=model.parameters(), lr=config.learning_rate)
    # Define the loss function
    loss_fn = torch.nn.CrossEntropyLoss()
    print("Setup complete, starting training")
    # Start training
    model = model.train()
    loss_train = 0
    for i in tqdm(range(config.iterations)):
        # Train on one batch
        x, y = next(train_dataloader)
        x = x.to(config.device)
        y = y.to(config.device)
        # After loading x, add noise to it
        # Sample the timestep t
        t = getRandomT(x.shape[0]).to(config.device)
        # First, draw noise with the same shape as x
        noise = torch.randn(size=x.shape, device=config.device)
        # Get the noised version of x
        x = diffusion_model.noiseX(x, noise, t)
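        # Assumption: noiseX applies the standard DDPM forward process,
        #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
        # so the classifier learns to predict labels from noised images
        # (see get_diffusion in Config for the actual implementation).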
        pred = model(x, t)
        loss = loss_fn(pred, y)
        loss_train += loss.item()
        opt.zero_grad()
        loss.backward()
        opt.step()
        if i % config.log_iter == 0:
            loss_train /= config.log_iter  # average per-batch loss over the logging window
            print("{} / {} -->{}".format(i, config.iterations, loss_train))
            wandb.log({"train_loss": loss_train})
            loss_train = 0
        if i % config.save_checkpoint == 0:
            # Time to save a checkpoint
            model_filename = os.path.join(config.save_root, "model_{}.pth".format(i // config.save_checkpoint))
            opt_filename = os.path.join(config.save_root, "opt_{}.pth".format(i // config.save_checkpoint))
            torch.save(model.state_dict(), model_filename)
            torch.save(opt.state_dict(), opt_filename)
            # Evaluate on the test set
            loss_test = 0
            acc_test = 0
            model = model.eval()
            with torch.no_grad():
                for x, y in test_dataloader:
                    x = x.to(config.device)
                    y = y.to(config.device)
                    # After loading x, add noise to it
                    # Sample the timestep t
                    # t = torch.randint(low=0, high=diffusion_model.T, size=(x.shape[0],), device=config.device)  # half-open interval [0, T)
                    t = getRandomT(x.shape[0]).to(config.device)
                    # First, draw noise with the same shape as x
                    noise = torch.randn(size=x.shape, device=config.device)
                    # Get the noised version of x
                    x = diffusion_model.noiseX(x, noise, t)
                    pred = model(x, t)
                    loss = loss_fn(pred, y)
                    loss_test += loss.item()
                    # Compute the accuracy
                    pred_softmax = softmax(pred, dim=-1)
                    pred_y = pred_softmax.max(dim=-1)[1]
                    acc_test += (pred_y == y).sum().item()
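                    # Note: softmax is monotonic, so taking argmax over the raw
                    # logits would give the same predicted labels; the softmax
                    # here only normalizes the scores.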
            loss_test /= len(test_dataloader)
            acc_test /= len(test_dataset)
            print("-" * 20)
            print("{} / {} loss-->{}, acc-->{}".format(i, config.iterations, loss_test, acc_test))
            print("-" * 20)
            wandb.log({"test_loss": loss_test, "test_acc": acc_test})
            model = model.train()
    wandb.finish()