Create your Gitee Account
Explore and code with more than 12 million developers. Free private repositories! :)
Sign up
文件
This repository doesn't specify a license. Please pay attention to the specific project description and its upstream code dependencies when using it.
Clone or Download
train_classifier.py 6.12 KB
Copy Edit Raw Blame History
"""
训练mnist分类器
"""
import os.path
from dataclasses import dataclass, field

import numpy as np
import torch
import torchvision.transforms
from torch.nn.functional import softmax
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm
import wandb

from Config import get_diffusion, ConfigMnist
from MnistClassifier import MnistClassifier

wandb.require("core")
def getRandomT(size):
    """Sample diffusion timesteps for a batch, uniformly and with replacement.

    The candidate timesteps are 980, 960, ..., 20 (every 20th step from 980
    down to 20, i.e. 49 values). Sampling uses NumPy's global RNG, so it is
    controlled by ``np.random.seed`` rather than ``torch.manual_seed``.

    Args:
        size: number of timesteps to draw (typically the batch size).

    Returns:
        1-D integer ``torch.Tensor`` of length ``size``.
    """
    candidates = np.arange(980, 0, -20)
    # np.random.choice defaults to replace=True, so timesteps may repeat
    # within a batch.
    return torch.tensor(np.random.choice(candidates, size=size))
class RescaleChannels:
    """Map a tensor from ``[0, 1]`` to ``[-1, 1]`` via ``2 * x - 1``.

    Defined at module level (not inside ``get_transform``) so instances can be
    pickled. ``DataLoader`` needs that when ``num_workers > 0`` and workers are
    started with ``spawn`` (the default on Windows and macOS).
    """

    def __call__(self, sample):
        return 2 * sample - 1


def get_transform():
    """Build the MNIST preprocessing pipeline producing tensors in ``[-1, 1]``.

    ``ToTensor`` converts the image to a float tensor in ``[0, 1]``, and
    ``RescaleChannels`` then maps it linearly onto ``[-1, 1]``.
    """
    return torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        RescaleChannels(),
    ])
@dataclass
class Config:
    """Hyper-parameters and paths for training the MNIST classifier on noised inputs.

    Every field has a default, so ``Config()`` reproduces the original
    settings. Individual values can be overridden by keyword, e.g.
    ``Config(batch_size_train=64)``. ``log_iter``, ``save_checkpoint`` and
    ``iterations`` are derived from ``train_size``, ``batch_size_train`` and
    ``epoch`` in ``__post_init__`` and cannot be passed to the constructor.
    """

    # Resolved per instance: "cuda" when a GPU is visible, otherwise "cpu".
    device: str = field(
        default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu"
    )
    batch_size_train: int = 128
    batch_size_test: int = 1000
    learning_rate: float = 0.01
    mnist_root: str = "../MNIST"
    save_root: str = "./log_classifier"
    # Total number of training epochs.
    epoch: int = 50
    # Number of training samples (the MNIST train split has 60000); used only
    # to derive the iteration-based schedules below.
    train_size: int = 60000
    # wandb settings
    wandb_project: str = "mnist_train"
    wandb_name: str = "with noise"
    # Model definition
    time_dim: int = 128
    drop_p: float = 0.5
    type_act: str = "silu"
    time_scale: float = 1.0
    # Derived in __post_init__ (not constructor arguments).
    log_iter: int = field(init=False)  # roughly one log per epoch
    save_checkpoint: int = field(init=False)  # roughly one checkpoint every four epochs
    iterations: int = field(init=False)  # total number of training iterations

    def __post_init__(self):
        steps_per_epoch = self.train_size / self.batch_size_train
        self.log_iter = int(steps_per_epoch)
        self.save_checkpoint = int(steps_per_epoch) * 4
        self.iterations = int(steps_per_epoch * self.epoch)
def cycle(dl):
    """Yield items from ``dl`` endlessly, restarting it after each full pass.

    Each pass re-iterates ``dl``, so a shuffling ``DataLoader`` produces a new
    ordering every epoch. ``itertools.cycle``, by contrast, replays cached items.

    Args:
        dl: a re-iterable collection such as a ``DataLoader`` or a list.

    Yields:
        The items of ``dl``, pass after pass.

    Raises:
        ValueError: if a pass yields no items (``dl`` is empty, or is a one-shot
            iterator that has been exhausted). Without this check the generator
            would loop forever without ever yielding.
    """
    while True:
        empty_pass = True
        for data in dl:
            empty_pass = False
            yield data
        if empty_pass:
            raise ValueError("cycle() got an iterable that yielded no items")
def _add_noise(diffusion_model, x, device):
    """Corrupt a clean batch with diffusion noise at randomly drawn timesteps.

    Args:
        diffusion_model: object exposing ``noiseX(x, noise, t)`` (defined
            elsewhere; presumably the DDPM forward process -- confirm).
        x: batch of images, already on ``device``.
        device: device on which the Gaussian noise and timesteps are created.

    Returns:
        ``(x_noisy, t)``: the noised batch and the per-sample timesteps used.
    """
    t = getRandomT(x.shape[0]).to(device)
    noise = torch.randn(size=x.shape, device=device)
    return diffusion_model.noiseX(x, noise, t), t


def _evaluate(model, diffusion_model, dataloader, loss_fn, device):
    """Evaluate ``model`` on noised batches from ``dataloader``.

    Returns:
        ``(mean_loss, accuracy)``. The loss is the mean of the per-batch losses;
        accuracy is the fraction of correctly classified samples.

    The model is put in eval mode for the pass and back in train mode on return.
    """
    loss_total = 0.0
    correct = 0
    n_samples = 0
    model.eval()
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            x, t = _add_noise(diffusion_model, x, device)
            pred = model(x, t)
            loss_total += loss_fn(pred, y).item()
            # argmax of the logits equals argmax of their softmax.
            correct += (pred.argmax(dim=-1) == y).sum().item()
            n_samples += y.shape[0]
    model.train()
    return loss_total / len(dataloader), correct / n_samples


def _save_checkpoint(model, opt, save_root, tag):
    """Save model and optimizer state dicts as ``model_{tag}.pth`` / ``opt_{tag}.pth``."""
    torch.save(model.state_dict(), os.path.join(save_root, "model_{}.pth".format(tag)))
    torch.save(opt.state_dict(), os.path.join(save_root, "opt_{}.pth".format(tag)))


if __name__ == "__main__":
    # The DDPM model is only used to add noise to the inputs; it is never trained.
    diffusion_model = get_diffusion(ConfigMnist()).requires_grad_(False)
    config = Config()
    wandb.init(project=config.wandb_project, name=config.wandb_name)

    # Define the classifier.
    model = MnistClassifier(config.time_dim, config.drop_p, config.type_act, config.time_scale)
    model = model.to(config.device)
    print("模型定义完毕")
    # makedirs also creates missing parent directories and tolerates an
    # already-existing directory.
    os.makedirs(config.save_root, exist_ok=True)

    # Load the datasets; get_transform() rescales every image to [-1, 1].
    train_dataset = datasets.MNIST(
        root=config.mnist_root,
        train=True,
        download=True,
        transform=get_transform(),
    )  # 60000 samples
    test_dataset = datasets.MNIST(
        root=config.mnist_root,
        train=False,
        download=True,
        transform=get_transform(),
    )  # 10000 samples
    # Training is iteration-based, so the train loader is wrapped to never end.
    train_dataloader = cycle(
        DataLoader(train_dataset, batch_size=config.batch_size_train, shuffle=True)
    )
    test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size_test, shuffle=True)
    print("数据集定义完毕")

    opt = torch.optim.Adam(params=model.parameters(), lr=config.learning_rate)
    loss_fn = torch.nn.CrossEntropyLoss()
    print("准备工作完成,开始训练")

    model = model.train()
    loss_train = 0.0
    n_accumulated = 0  # batches summed into loss_train since the last log
    for i in tqdm(range(config.iterations)):
        x, y = next(train_dataloader)
        x = x.to(config.device)
        y = y.to(config.device)
        # The classifier is trained on noised inputs, conditioned on the timestep.
        x, t = _add_noise(diffusion_model, x, config.device)
        pred = model(x, t)
        loss = loss_fn(pred, y)
        loss_train += loss.item()
        n_accumulated += 1
        opt.zero_grad()
        loss.backward()
        opt.step()
        if i % config.log_iter == 0:
            # Average over the batches actually accumulated: at i == 0 only one
            # batch has been seen, not log_iter of them.
            loss_train /= n_accumulated
            print("{} / {} -->{}".format(i, config.iterations, loss_train))
            wandb.log({"train_loss": loss_train})
            loss_train = 0.0
            n_accumulated = 0
        if i % config.save_checkpoint == 0:
            _save_checkpoint(model, opt, config.save_root, i // config.save_checkpoint)
            loss_test, acc_test = _evaluate(
                model, diffusion_model, test_dataloader, loss_fn, config.device
            )
            print("-" * 20)
            print("{} / {} loss-->{}, acc-->{}".format(i, config.iterations, loss_test, acc_test))
            print("-" * 20)
            wandb.log({"test_loss": loss_test, "test_acc": acc_test})
    # The periodic checkpoints generally miss the last iterations, so always
    # persist the final weights as well.
    _save_checkpoint(model, opt, config.save_root, "final")
    wandb.finish()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化