"""
DDPM的全部过程,包括前向扩散和反向去噪
"""
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
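# Classifier-guidance helper: given a (possibly noisy) batch x at step t and target
# labels y, compute the gradient of the classifier's log-probability of y with
# respect to x, i.e. grad_x log p(y | x, t). This gradient is later used to steer
# the predicted noise toward samples of class y.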
def getGrad(x, t, y, model, device):
with torch.enable_grad():
x = x.requires_grad_(True)
ret = F.log_softmax(model(x, t), dim=1)
selected = ret.gather(1, torch.tensor(y).reshape(-1, 1).to(device))
return torch.autograd.grad(selected.sum(), x)[0]
def exract(a, index, device):
    """
    Gather the per-timestep coefficients of `a` at positions `index` and reshape
    them to (batch, 1, 1, 1) so they broadcast over image tensors.
    :param a: 1-D tensor of schedule coefficients
    :param index: batch of timestep indices
    :param device: device to place the result on
    :return: gathered coefficients with shape (-1, 1, 1, 1)
    """
    return a.to(device).gather(-1, index.to(device)).reshape(-1, 1, 1, 1)
# Produce a linear schedule of T values from low to high (returned as a NumPy array).
def generate_linear_schedule(T, low, high):
return np.linspace(low, high, T)
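# A small usage sketch (the T/low/high values below are common DDPM defaults, not
# values fixed by this repo). The schedule comes back as a NumPy array, while DDPM
# registers buffers from it, so convert it to a torch tensor first. `unet` stands
# for any noise-prediction network taking (x, t) and is hypothetical here:
#   betas = torch.from_numpy(generate_linear_schedule(1000, 1e-4, 0.02)).float()
#   ddpm = DDPM(unet, betas, img_channel=3, img_h=32, img_w=32)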
class DDPM(nn.Module):
    def __init__(self, model, belta, img_channel, img_h, img_w, device="cuda", type_loss="l1"):
        super().__init__()
        assert type_loss in ("l1", "l2")
        # choose the noise-regression loss
        self.loss_fn = F.l1_loss if type_loss == "l1" else F.mse_loss
self.model = model
self.belta = belta
self.channel = img_channel
self.h = img_h
self.w = img_w
        self.T = len(belta)  # total number of diffusion steps
        # precompute the other schedule-derived coefficients
self.register_buffer("alpha", 1 - belta)
self.register_buffer("alpha_bar", torch.cumprod(self.alpha, dim=0))
self.register_buffer("one_sub_alpha_bar", 1 - self.alpha_bar)
self.register_buffer("sqrt_one_sub_alpah_bar", torch.sqrt(self.one_sub_alpha_bar))
self.register_buffer("sqrt_alpha_bar", torch.sqrt(self.alpha_bar))
self.register_buffer("sqrt_alpha", torch.sqrt(self.alpha))
self.remove_noise_coeff = self.belta / self.sqrt_one_sub_alpah_bar
self.sigma = torch.sqrt(self.belta)
self.device = device
    # Add noise to an image: the closed-form forward process q(x_t | x_0).
    def noiseX(self, x, epsilon, t):
        """
        Given the clean image, a noise sample and the timestep, return the noised image
        x_t = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * epsilon.
        :param x: clean image batch
        :param epsilon: Gaussian noise with the same shape as x
        :param t: batch of timestep indices
        :return: noised image batch
        """
        # gather the two per-timestep coefficients
# para_1 = self.sqrt_alpha_bar.gather(-1, t).reshape(-1, 1, 1, 1)
para_1 = exract(self.sqrt_alpha_bar, t, self.device)
# para_2 = self.sqrt_one_sub_alpah_bar.gather(-1, t).reshape(-1, 1, 1, 1)
para_2 = exract(self.sqrt_one_sub_alpah_bar, t, self.device)
return para_1 * x + para_2 * epsilon
    # Training step: diffuse x and train the network to predict the added noise.
def forward(self, x):
b, c, h, w = x.shape
assert h==self.h and w==self.w
        # sample a timestep for each image (randint is half-open: [0, T))
        t = torch.randint(low=0, high=self.T, size=(b,), device=self.device)
        # draw Gaussian noise with the same shape as x
        noise = torch.randn(size=x.shape, device=self.device)
        # build the noised image x_t
        x_noise = self.noiseX(x, noise, t)
        # let the network predict the noise from x_t and t
noise_pred = self.model(x_noise, t)
return self.loss_fn(noise_pred, noise)
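    # A typical training loop around forward() (sketch only; `dataloader`, `optimizer`
    # and `device` are assumed to exist elsewhere, they are not defined in this file):
    #   for x0, _ in dataloader:
    #       loss = ddpm(x0.to(device))
    #       optimizer.zero_grad()
    #       loss.backward()
    #       optimizer.step()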
    # My own sampling procedure (ancestral DDPM sampling).
@torch.no_grad()
def sample_me(self, batch_size):
        # start from pure Gaussian noise
        latent = torch.randn(size=(batch_size, self.channel, self.h, self.w), device=self.device)
        # walk the diffusion steps backwards: diffusion runs over [0, T-1], so sampling reverses it
        for i in range(self.T-1, -1, -1):
            # every sample in the batch uses the same timestep i
            t = torch.tensor([i] * batch_size, device=self.device)
            # posterior mean: mju = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t)
            noise_pred = self.model(latent, t)
            mju = (latent - exract(self.belta, t, self.device) / exract(self.sqrt_one_sub_alpah_bar, t, self.device) * noise_pred) / exract(self.sqrt_alpha, t, self.device)
            # add the variance term (skipped at the final step)
            if i > 0:
                # posterior variance: sigma_t^2 = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
                sigma_2 = exract(self.belta, t, self.device) * (1 - exract(self.alpha_bar, t-1, self.device)) / (1 - exract(self.alpha_bar, t, self.device))
latent = mju + torch.sqrt(sigma_2) * torch.randn(size=mju.shape, device=self.device)
# print(i, " latent[0,0,0,0]-->", latent[0,0,0,0].item(), " sigma_2[0, 0, 0, 0]-->", torch.sqrt(sigma_2)[0,0,0,0].item())
# latent = mju + exract(self.sigma, t, "cuda") * torch.randn(size=mju.shape, device="cuda")
else:
latent = mju
return latent
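    # Usage sketch (assumes a trained model): imgs = ddpm.sample_me(16) returns a
    # (16, channel, h, w) tensor of generated images, still on self.device.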
@torch.no_grad()
    # This computes the posterior mean (removes the predicted noise from x).
def remove_noise(self, x, t, device="cuda"):
return (
(x - exract(self.remove_noise_coeff, t, device) * self.model(x, t)) / exract(self.sqrt_alpha, t, device)
)
    # The original author's sampling procedure.
@torch.no_grad()
def sample(self, batch_size, device="cuda"):
        # start from pure Gaussian noise
x = torch.randn(batch_size, self.channel, self.h, self.w, device=device)
for t in range(self.T - 1, -1, -1):
t_batch = torch.tensor([t], device=device).repeat(batch_size)
x = self.remove_noise(x, t_batch, device)
if t > 0:
                x += exract(self.sigma, t_batch, device) * torch.randn(size=x.shape, device=device)  # reparameterization; uses sigma_t = sqrt(beta_t), which does not match the posterior-variance formula used above
return x.cpu().detach()
    # DDIM sampling: deterministic updates (eta = 0) over a strided subset of timesteps.
@torch.no_grad()
def sample_ddim(self, batch_size):
        # start from pure Gaussian noise
        latent = torch.randn(size=(batch_size, self.channel, self.h, self.w), device=self.device)
        # walk a strided subsequence of timesteps backwards (assumes T = 1000, stride 20)
        for i in range(980, 0, -20):
            # every sample in the batch uses the same timestep i
            t = torch.tensor([i] * batch_size, device=self.device)
            # DDIM step (eta = 0): predict x0 from x_t, then jump it back to noise level t-20
noise_pred = self.model(latent, t)
a = exract(self.sqrt_alpha_bar, t-20, self.device)
b = exract(self.sqrt_one_sub_alpah_bar, t, self.device) * noise_pred
c = exract(self.sqrt_alpha_bar, t, self.device)
d = exract(self.sqrt_one_sub_alpah_bar, t-20, self.device) * noise_pred
mju = a * (latent - b) / c + d
            latent = mju  # eta = 0: the variance term is simply dropped
return latent
# Classifier guidance
@torch.no_grad()
def sample_ddim_cg(self, batch_size, y, model):
        # start from pure Gaussian noise
        latent = torch.randn(size=(batch_size, self.channel, self.h, self.w), device=self.device)
        # walk a strided subsequence of timesteps backwards (assumes T = 1000, stride 20)
        for i in range(980, 0, -20):
            # every sample in the batch uses the same timestep i
            t = torch.tensor([i] * batch_size, device=self.device)
            # predict the noise, then shift it with the classifier gradient:
            # eps_hat = eps - sqrt(1 - alpha_bar_t) * grad_x log p(y | x_t)
            noise_pred = self.model(latent, t)
            grad = getGrad(latent, t, y, model, self.device)
            noise_pred -= grad * exract(self.sqrt_one_sub_alpah_bar, t, self.device)
a = exract(self.sqrt_alpha_bar, t-20, self.device)
b = exract(self.sqrt_one_sub_alpah_bar, t, self.device) * noise_pred
c = exract(self.sqrt_alpha_bar, t, self.device)
d = exract(self.sqrt_one_sub_alpah_bar, t-20, self.device) * noise_pred
mju = a * (latent - b) / c + d
            latent = mju  # eta = 0: the variance term is simply dropped
return latent
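    # Usage sketch (hypothetical names): `clf` is a classifier taking (x, t) and
    # returning per-class logits; y holds one target label per sample.
    #   imgs = ddpm.sample_ddim_cg(batch_size=4, y=[0, 1, 2, 3], model=clf)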
    # An earlier, unfinished classifier-guidance sampler (kept commented out)
    # @torch.no_grad()
    # def sample_classifier(self, batch_size, classifier):
    #     # start from pure Gaussian noise
    #     latent = torch.randn(size=(batch_size, self.channel, self.h, self.w), device=self.device)
    #
    #     # walk the diffusion steps backwards
    #     for i in range(self.T-1, -1, -1):
    #         # every sample in the batch uses the same timestep i
    #         t = torch.tensor([i] * batch_size, device=self.device)
    #         # run the UNet
    #         noise_pred = self.model(latent, t)
    #         # (the classifier gradient was presumably meant to be applied here)
    #
    #         mju = (latent - exract(self.belta, t, self.device) / exract(self.sqrt_one_sub_alpah_bar, t, self.device) * noise_pred) / exract(self.sqrt_alpha, t, self.device)
    #
    #         # add the variance term
    #         if i > 0:
    #             sigma_2 = exract(self.belta, t, self.device) * (1 - exract(self.alpha_bar, t-1, self.device)) / (1 - exract(self.alpha_bar, t, self.device))
    #             latent = mju + torch.sqrt(sigma_2) * torch.randn(size=mju.shape, device=self.device)
    #         else:
    #             latent = mju
    #
    #     return latent
# def test(a, b):
# return a
#
# if __name__ == "__main__":
#
# ddpm = DDPM(test, torch.Tensor([1, 2, 3, 4, 5]), 3, 4, 5)
# print(ddpm.sample(10).shape)
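# A minimal end-to-end sketch (assumptions: DummyEps is a stand-in for the real
# noise-prediction UNet, and the schedule values are common DDPM defaults, not
# values fixed by this repo). Runs on CPU so no GPU is required.
if __name__ == "__main__":
    class DummyEps(nn.Module):
        """Placeholder epsilon-predictor: returns zeros shaped like the input."""
        def forward(self, x, t):
            return torch.zeros_like(x)

    betas = torch.from_numpy(generate_linear_schedule(1000, 1e-4, 0.02)).float()
    ddpm = DDPM(DummyEps(), betas, img_channel=3, img_h=8, img_w=8, device="cpu")

    x0 = torch.rand(2, 3, 8, 8)          # fake "clean" images in [0, 1]
    loss = ddpm(x0)                      # forward diffusion + noise-prediction loss
    imgs = ddpm.sample(2, device="cpu")  # reverse-process (DDPM) sampling
    print(loss.item(), imgs.shape)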