代码拉取完成,页面将自动刷新
from dataclasses import dataclass
from unet_mine import Unet
from diffusion import DDPM, generate_linear_schedule
import torch
# cifar
@dataclass
class ConfigCifar():
def __init__(self):
# wandb相关
self.wandb_project = "ddpm"
self.wandb_name = "unet_mine"
# 训练相关超参数
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.learning_rate = 2e-4
self.batch_size = 128
self.dataset_root = '../cifar_train'
self.iterations = 800000
self.log_rate=400
self.checkpoint_rate = 2000
self.log_dir = "./ddpm_logs"
# ddpm相关
self.schedule_low = 1e-4
self.schedule_high = 0.02
self.num_timesteps = 1000
self.loss_type = "l2"
# unet相关超参数
self.img_channel = 3
self.img_h = 32
self.img_w = 32
self.time_dim = 128
self.hidden_channel_list = [64, 128, 256, 512, 1024]
self.dropout = 0.5
# self.activation = "relu"
self.activation = "silu"
self.time_emb_scale=1.0
# MNIST
@dataclass
class ConfigMnist():
def __init__(self):
# wandb相关
self.wandb_project = "ddpm"
self.wandb_name = "silu_mnist"
# 训练相关超参数
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.learning_rate = 2e-4
self.batch_size = 128
self.dataset_root = '../MNIST'
self.iterations = 800000
self.log_rate=400
self.checkpoint_rate = 2000
self.log_dir = "./ddpm_logs"
# ddpm相关
self.schedule_low = 1e-4
self.schedule_high = 0.02
self.num_timesteps = 1000
self.loss_type = "l2"
# unet相关超参数
self.img_channel = 1
self.img_h = 28
self.img_w = 28
self.time_dim = 128
self.hidden_channel_list = [64, 128, 256]
self.dropout = 0.5
# self.activation = "relu"
self.activation = "silu"
self.time_emb_scale=1.0
# 定义模型
def get_diffusion(config):
# 首先定义Unet网络
unet = Unet(config.img_channel, config.time_dim, config.hidden_channel_list, config.dropout, config.activation, config.time_emb_scale)
# 定义ddpm模型
belta = generate_linear_schedule(config.num_timesteps, config.schedule_low, config.schedule_high)
belta = torch.Tensor(belta)
ddpm = DDPM(unet, belta, config.img_channel, config.img_h, config.img_w, config.device, config.loss_type)
return ddpm
if __name__ == "__main__":
config = ConfigMnist()
model = get_diffusion(config).model
x = torch.randn(size=(5, 1, 28, 28))
t = torch.arange(5)
ret = model(x, t)
print(ret.shape)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。