加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
Config.py 2.70 KB
一键复制 编辑 原始数据 按行查看 历史
jcy 提交于 2024-09-13 11:53 . 调整unet结构,使之可以适应MNIST
from dataclasses import dataclass, field

import torch

from diffusion import DDPM, generate_linear_schedule
from unet_mine import Unet
# CIFAR configuration
@dataclass
class ConfigCifar:
    """Hyperparameters for training a DDPM on CIFAR-style 3x32x32 images.

    Every field has a default, so ``ConfigCifar()`` yields the same settings
    as before; individual fields may also be overridden by keyword, e.g.
    ``ConfigCifar(batch_size=64)``. The attributes are declared as dataclass
    fields (rather than assigned in a hand-written ``__init__``) so that
    ``==``, ``repr`` and ``dataclasses.asdict`` actually reflect their values.
    """

    # wandb logging
    wandb_project: str = "ddpm"
    wandb_name: str = "unet_mine"

    # Training hyperparameters
    # Resolved when each instance is created: "cuda" if available, else "cpu".
    device: str = field(
        default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu"
    )
    learning_rate: float = 2e-4
    batch_size: int = 128
    dataset_root: str = "../cifar_train"
    iterations: int = 800000
    log_rate: int = 400
    checkpoint_rate: int = 2000
    log_dir: str = "./ddpm_logs"

    # DDPM settings; get_diffusion passes the first three to
    # generate_linear_schedule and loss_type to DDPM.
    schedule_low: float = 1e-4
    schedule_high: float = 0.02
    num_timesteps: int = 1000
    loss_type: str = "l2"

    # UNet hyperparameters (img_channel/img_h/img_w are also passed to DDPM)
    img_channel: int = 3
    img_h: int = 32
    img_w: int = 32
    time_dim: int = 128
    # default_factory gives every instance its own list; a plain list default
    # would be shared between instances (and dataclass rejects it anyway).
    hidden_channel_list: list = field(
        default_factory=lambda: [64, 128, 256, 512, 1024]
    )
    dropout: float = 0.5
    activation: str = "silu"  # "relu" was kept as a commented-out alternative
    time_emb_scale: float = 1.0
# MNIST configuration
@dataclass
class ConfigMnist:
    """Hyperparameters for training a DDPM on MNIST-style 1x28x28 images.

    Every field has a default, so ``ConfigMnist()`` yields the same settings
    as before; individual fields may also be overridden by keyword, e.g.
    ``ConfigMnist(dropout=0.1)``. The attributes are declared as dataclass
    fields (rather than assigned in a hand-written ``__init__``) so that
    ``==``, ``repr`` and ``dataclasses.asdict`` actually reflect their values.
    """

    # wandb logging
    wandb_project: str = "ddpm"
    wandb_name: str = "silu_mnist"

    # Training hyperparameters
    # Resolved when each instance is created: "cuda" if available, else "cpu".
    device: str = field(
        default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu"
    )
    learning_rate: float = 2e-4
    batch_size: int = 128
    dataset_root: str = "../MNIST"
    iterations: int = 800000
    log_rate: int = 400
    checkpoint_rate: int = 2000
    log_dir: str = "./ddpm_logs"

    # DDPM settings; get_diffusion passes the first three to
    # generate_linear_schedule and loss_type to DDPM.
    schedule_low: float = 1e-4
    schedule_high: float = 0.02
    num_timesteps: int = 1000
    loss_type: str = "l2"

    # UNet hyperparameters (img_channel/img_h/img_w are also passed to DDPM).
    # Only three levels here, versus five in the CIFAR config.
    img_channel: int = 1
    img_h: int = 28
    img_w: int = 28
    time_dim: int = 128
    # default_factory gives every instance its own list; a plain list default
    # would be shared between instances (and dataclass rejects it anyway).
    hidden_channel_list: list = field(default_factory=lambda: [64, 128, 256])
    dropout: float = 0.5
    activation: str = "silu"  # "relu" was kept as a commented-out alternative
    time_emb_scale: float = 1.0
# Model construction
def get_diffusion(config):
    """Assemble a DDPM whose denoising network is a ``Unet`` built from *config*.

    Args:
        config: a configuration object such as ``ConfigCifar`` or
            ``ConfigMnist``. Its UNet fields (``img_channel``, ``time_dim``,
            ``hidden_channel_list``, ``dropout``, ``activation``,
            ``time_emb_scale``), schedule fields (``num_timesteps``,
            ``schedule_low``, ``schedule_high``) and DDPM fields
            (``img_h``, ``img_w``, ``device``, ``loss_type``) are read.

    Returns:
        The ``DDPM`` instance wrapping the newly created ``Unet``.
    """
    # Step 1: the noise-prediction network (arguments are positional, in the
    # order Unet expects them).
    denoiser = Unet(
        config.img_channel,
        config.time_dim,
        config.hidden_channel_list,
        config.dropout,
        config.activation,
        config.time_emb_scale,
    )

    # Step 2: the beta schedule, converted to a float32 torch tensor.
    schedule = generate_linear_schedule(
        config.num_timesteps,
        config.schedule_low,
        config.schedule_high,
    )
    betas = torch.Tensor(schedule)

    # Step 3: the diffusion wrapper around the network and schedule.
    return DDPM(
        denoiser,
        betas,
        config.img_channel,
        config.img_h,
        config.img_w,
        config.device,
        config.loss_type,
    )
if __name__ == "__main__":
    # Smoke test: build the MNIST diffusion model and print the shape of the
    # UNet's output for a small random batch.
    config = ConfigMnist()
    model = get_diffusion(config).model

    batch_size = 5
    # Put the inputs on the same device as the network's weights (assumes
    # `model` is a torch.nn.Module with at least one parameter), so the check
    # also works if the model was moved to CUDA.
    device = next(model.parameters()).device
    # Take the image shape from the config instead of hard-coding MNIST's
    # 1x28x28, so swapping in ConfigCifar() needs no other change.
    x = torch.randn(
        size=(batch_size, config.img_channel, config.img_h, config.img_w),
        device=device,
    )
    t = torch.arange(batch_size, device=device)

    # Pure shape check: no autograd graph is needed.
    with torch.no_grad():
        ret = model(x, t)
    print(ret.shape)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化