"""
在这里定义属于自己的Unet
"""
import torch
from torch import nn
import math
def getAct(type_act):
if type_act == "relu":
return nn.ReLU
elif type_act == "sigmoid":
return nn.Sigmoid
elif type_act == "silu":
return nn.SiLU
else:
print("请输入正确的激活函数")
assert (False)
# Turns the input timestep t into an embedding.
# Uses a fixed, hard-coded sinusoidal scheme (not learned); not the focus here.
class PositionalEmbedding(nn.Module):
__doc__ = r"""Computes a positional embedding of timesteps.
Input:
x: tensor of shape (N)
Output:
tensor of shape (N, dim)
Args:
dim (int): embedding dimension
scale (float): linear scale to be applied to timesteps. Default: 1.0
"""
def __init__(self, dim, scale=1.0):
super().__init__()
assert dim % 2 == 0
self.dim = dim
self.scale = scale
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / half_dim
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = torch.outer(x * self.scale, emb)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
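# A worked example (sketch, values not from the original code): for dim=4, scale=1.0
# and a single timestep t=2, half_dim=2 and the frequencies are
# exp(-log(10000)/2 * [0, 1]) = [1.0, 0.01], so the embedding is
# [sin(2.0), sin(0.02), cos(2.0), cos(0.02)] with shape (1, 4):
#
#     emb = PositionalEmbedding(dim=4)(torch.tensor([2.0]))
#     print(emb.shape)  # torch.Size([1, 4])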
# Residual block with time-embedding injection
class ResModule(nn.Module):
def __init__(self, in_channel, out_channel, time_dim, drop_p = 0.5, type_act="relu"):
super(ResModule, self).__init__()
self.in_channel = in_channel
self.out_channel = out_channel
self.time_dim = time_dim
self.act = getAct(type_act)
self.drop_p = drop_p
self.conv1 = nn.Sequential(
nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channel),
self.act()
)
self.conv2 = nn.Sequential(
nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channel),
self.act()
)
self.time_handle = nn.Sequential(
nn.Linear(self.time_dim, self.out_channel),
nn.Dropout(self.drop_p),
self.act()
)
        self.res = None  # the skip path must produce the same channel count so the final residual add works
if self.in_channel == self.out_channel:
self.res = nn.Identity()
else:
self.res = nn.Sequential(
nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channel),
self.act()
)
def forward(self, x, time_embedding):
out = self.conv1(x)
time_embedding = self.time_handle(time_embedding)
time_embedding = time_embedding[:, :, None, None]
out = out + time_embedding
out = self.conv2(out)
x = self.res(x)
        out = out + x  # must not use `out += x` here, since that is an in-place operation
return out
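# Usage sketch (hypothetical shapes): the projected time embedding is broadcast over
# H and W, so spatial size is preserved while channels go from in_channel to out_channel:
#
#     block = ResModule(in_channel=3, out_channel=64, time_dim=16)
#     t_emb = PositionalEmbedding(16)(torch.arange(8).float())  # (8, 16)
#     y = block(torch.randn(8, 3, 32, 32), t_emb)               # (8, 64, 32, 32)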
# Downsampling block
# keeps the channel count unchanged and only halves H and W
class DownModule(nn.Module):
def __init__(self, channel, type_act="relu"):
super(DownModule, self).__init__()
self.channel = channel
self.act = getAct(type_act)
self.conv = nn.Sequential(
nn.Conv2d(self.channel, self.channel, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(self.channel),
self.act()
)
def forward(self, x):
return self.conv(x)
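# Shape sketch: the stride-2 4x4 convolution with padding 1 halves H and W exactly
# for even inputs while leaving C unchanged, e.g.
#
#     DownModule(64)(torch.randn(8, 64, 32, 32)).shape  # torch.Size([8, 64, 16, 16])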
# Upsampling block
# this one not only doubles H and W but also halves the channel count
class UpModule(nn.Module):
def __init__(self, channel, type_act="relu"):
super(UpModule, self).__init__()
self.channel = channel
self.act = getAct(type_act)
self.conv = nn.Sequential(
            # upsample spatially first
nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(self.channel, self.channel // 2, kernel_size=3, padding=1),  # integer division is required here so the channel count stays an int
nn.BatchNorm2d(self.channel // 2),
self.act()
)
def forward(self, x):
return self.conv(x)
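# Shape sketch: nearest-neighbour upsampling doubles H and W, then the 3x3 convolution
# halves the channel count, e.g.
#
#     UpModule(128)(torch.randn(8, 128, 16, 16)).shape  # torch.Size([8, 64, 32, 32])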
class Unet(nn.Module):
def __init__(self, channel, time_dim, hidden_channel_list, drop_p = 0.5, type_act="relu", time_emb_scale=1.0):
super(Unet, self).__init__()
self.channel = channel
self.time_dim = time_dim
self.drop_p = drop_p
self.down_list = nn.ModuleList()
self.up_list = nn.ModuleList()
        # The first stage is slightly different: it has no Down module.
        # These cannot be packed into nn.Sequential, because ResModule takes two inputs while the Down/Up modules take only one.
self.down_list.append(ResModule(channel, hidden_channel_list[0], time_dim, drop_p, type_act))
self.down_list.append(ResModule(hidden_channel_list[0], hidden_channel_list[0], time_dim, drop_p, type_act))
        for i in range(1, len(hidden_channel_list)):  # down, res, res pattern
self.down_list.append(DownModule(hidden_channel_list[i-1], type_act))
self.down_list.append(ResModule(hidden_channel_list[i-1], hidden_channel_list[i], time_dim, drop_p, type_act))
self.down_list.append(ResModule(hidden_channel_list[i], hidden_channel_list[i], time_dim, drop_p, type_act))
        # Upsampling path
self.up_list.append(UpModule(hidden_channel_list[-1], type_act))
        for i in range(len(hidden_channel_list)-1, 0, -1):  # res, res, up pattern
self.up_list.append(ResModule(hidden_channel_list[i], hidden_channel_list[i-1], time_dim, drop_p, type_act))
self.up_list.append(ResModule(hidden_channel_list[i-1], hidden_channel_list[i-1], time_dim, drop_p, type_act))
self.up_list.append(UpModule(hidden_channel_list[i-1], type_act))
        # project back to the original number of input channels (e.g. 3 for RGB)
self.recovery_channel = nn.Conv2d(hidden_channel_list[0], channel, kernel_size=1)
        # timestep encoder
self.embed_time = PositionalEmbedding(time_dim, time_emb_scale)
# print(len(self.down_list))
# print(len(self.up_list))
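        # Module-count sketch: for a hidden_channel_list of length L the layout is
        #   down_list = [Res, Res] + (L-1) * [Down, Res, Res]  -> 2 + 3*(L-1) modules
        #   up_list   = [Up] + (L-1) * [Res, Res, Up]          -> 1 + 3*(L-1) modules
        # e.g. L = 5 gives 14 down modules and 13 up modules (the final Up is unused).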
def forward(self, x, t):
        # encode the timestep first
time_embedding = self.embed_time(t)
latent_list = []
        # first downsampling stage (no Down module)
for i in range(2):
x = self.down_list[i](x, time_embedding)
latent_list.append(x)
        # remaining downsampling stages
# for i in range(2, 14, 3):
for i in range(2, len(self.down_list), 3):
x = self.down_list[i](x)
x = self.down_list[i+1](x, time_embedding)
x = self.down_list[i+2](x, time_embedding)
latent_list.append(x)
        # upsampling path
j = -2
# print(x.shape)
# for i in range(0, 12, 3):
        for i in range(0, len(self.up_list)-1, 3):  # the last UpModule is never used
x = self.up_list[i](x)
# print(x.shape)
x = torch.cat((x, latent_list[j]), dim=1)
# print(x.shape)
j -= 1
x = self.up_list[i+1](x, time_embedding)
x = self.up_list[i+2](x, time_embedding)
x = self.recovery_channel(x)
return x
if __name__ == "__main__":
m = Unet(3, 16, [64, 128, 256, 512, 1024])
ret = m(torch.randn(8, 3, 32, 32), torch.arange(8))
print(ret.shape)
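    # A second sanity check (sketch, not part of the original test): a shallower
    # configuration on larger images; the output shape should match the input shape.
    m2 = Unet(3, 32, [32, 64, 128])
    ret2 = m2(torch.randn(4, 3, 64, 64), torch.arange(4))
    print(ret2.shape)  # expected: torch.Size([4, 3, 64, 64])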