Config.py
from dataclasses import dataclass, field

import torch

from unet_mine import Unet
from diffusion import DDPM, generate_linear_schedule


# CIFAR configuration
@dataclass
class ConfigCifar:
    # wandb settings
    wandb_project: str = "ddpm"
    wandb_name: str = "unet_mine"
    # training hyperparameters
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    learning_rate: float = 2e-4
    batch_size: int = 128
    dataset_root: str = "../cifar_train"
    iterations: int = 800000
    log_rate: int = 400
    checkpoint_rate: int = 2000
    log_dir: str = "./ddpm_logs"
    # DDPM hyperparameters
    schedule_low: float = 1e-4
    schedule_high: float = 0.02
    num_timesteps: int = 1000
    loss_type: str = "l2"
    # UNet hyperparameters
    img_channel: int = 3
    img_h: int = 32
    img_w: int = 32
    time_dim: int = 128
    hidden_channel_list: list = field(default_factory=lambda: [64, 128, 256, 512, 1024])
    dropout: float = 0.5
    activation: str = "silu"  # alternatively "relu"
    time_emb_scale: float = 1.0


# MNIST configuration
@dataclass
class ConfigMnist:
    # wandb settings
    wandb_project: str = "ddpm"
    wandb_name: str = "silu_mnist"
    # training hyperparameters
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    learning_rate: float = 2e-4
    batch_size: int = 128
    dataset_root: str = "../MNIST"
    iterations: int = 800000
    log_rate: int = 400
    checkpoint_rate: int = 2000
    log_dir: str = "./ddpm_logs"
    # DDPM hyperparameters
    schedule_low: float = 1e-4
    schedule_high: float = 0.02
    num_timesteps: int = 1000
    loss_type: str = "l2"
    # UNet hyperparameters
    img_channel: int = 1
    img_h: int = 28
    img_w: int = 28
    time_dim: int = 128
    hidden_channel_list: list = field(default_factory=lambda: [64, 128, 256])
    dropout: float = 0.5
    activation: str = "silu"  # alternatively "relu"
    time_emb_scale: float = 1.0
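

# Example usage (illustrative, not part of the original training code):
# a config object is plain data, so individual hyperparameters can be
# tweaked per experiment after construction, e.g.
#     config = ConfigMnist()
#     config.batch_size = 64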


def get_diffusion(config):
    # Build the UNet noise-prediction network
    unet = Unet(config.img_channel, config.time_dim, config.hidden_channel_list,
                config.dropout, config.activation, config.time_emb_scale)
    # Wrap it in the DDPM model with a linear beta (noise variance) schedule
    beta = generate_linear_schedule(config.num_timesteps, config.schedule_low, config.schedule_high)
    beta = torch.Tensor(beta)
    ddpm = DDPM(unet, beta, config.img_channel, config.img_h, config.img_w,
                config.device, config.loss_type)
    return ddpm
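

# Illustrative only: generate_linear_schedule (defined in diffusion.py) is
# assumed to return the standard DDPM linear noise schedule, i.e. num_timesteps
# values of beta_t spaced evenly between schedule_low and schedule_high. The
# helper below is a minimal sketch of that assumed contract, not a copy of the
# real implementation.
def linear_beta_schedule_sketch(num_timesteps: int, low: float, high: float) -> torch.Tensor:
    # beta_t increases linearly from `low` to `high` over `num_timesteps` steps
    return torch.linspace(low, high, num_timesteps)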


if __name__ == "__main__":
    # Quick sanity check: forward a random MNIST-shaped batch through the underlying network
    config = ConfigMnist()
    model = get_diffusion(config).model
    x = torch.randn(size=(5, 1, 28, 28))
    t = torch.arange(5)
    ret = model(x, t)
    print(ret.shape)
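
    # Illustrative extra check (not in the original file): the same forward
    # pass for the CIFAR config, with a batch shaped to its img_channel,
    # img_h and img_w settings.
    config_cifar = ConfigCifar()
    model_cifar = get_diffusion(config_cifar).model
    x_cifar = torch.randn(size=(5, config_cifar.img_channel, config_cifar.img_h, config_cifar.img_w))
    t_cifar = torch.arange(5)
    print(model_cifar(x_cifar, t_cifar).shape)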