"""
构建mnist的分类器
"""
import torch
from torch import nn
import math
def getAct(type_act):
    """Return the activation class (not an instance) named by type_act."""
    if type_act == "relu":
        return nn.ReLU
    elif type_act == "sigmoid":
        return nn.Sigmoid
    elif type_act == "silu":
        return nn.SiLU
    else:
        raise ValueError(f"Please pass a valid activation function, got {type_act!r}")
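# Note: getAct returns the class itself; callers instantiate it where needed,
# e.g. getAct("silu")() behaves the same as nn.SiLU().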
# Turns the input timestep t into an embedding vector.
# This is the standard fixed (non-learned) sinusoidal encoding.
class PositionalEmbedding(nn.Module):
    __doc__ = r"""Computes a positional embedding of timesteps.
    Input:
        x: tensor of shape (N)
    Output:
        tensor of shape (N, dim)
    Args:
        dim (int): embedding dimension
        scale (float): linear scale to be applied to timesteps. Default: 1.0
    """

    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
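# Worked example (a sketch, assuming dim=4): the frequencies are
# w_i = 10000^(-i/half_dim) for i in {0, 1}, and each output row is
# [sin(t*w0), sin(t*w1), cos(t*w0), cos(t*w1)]:
#   pe = PositionalEmbedding(4)
#   pe(torch.tensor([1.0, 2.0])).shape  # torch.Size([2, 4])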
# Residual block with time-embedding injection.
class ResModule(nn.Module):
    def __init__(self, in_channel, out_channel, time_dim, drop_p=0.5, type_act="relu"):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.time_dim = time_dim
        self.act = getAct(type_act)
        self.drop_p = drop_p
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channel),
            self.act()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channel),
            self.act()
        )
        self.time_handle = nn.Sequential(
            nn.Linear(self.time_dim, self.out_channel),
            nn.Dropout(self.drop_p),
            self.act()
        )
        # The shortcut must produce out_channel channels so the residual add works.
        if self.in_channel == self.out_channel:
            self.res = nn.Identity()
        else:
            self.res = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channel),
                self.act()
            )

    def forward(self, x, time_embedding):
        out = self.conv1(x)
        time_embedding = self.time_handle(time_embedding)
        time_embedding = time_embedding[:, :, None, None]
        out = out + time_embedding
        out = self.conv2(out)
        x = self.res(x)
        out = out + x  # Do not write out += x here: that is an in-place operation.
        return out
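# Shape sketch for the time injection above (assuming batch size N):
# time_embedding (N, time_dim) -> Linear -> (N, out_channel)
# -> [:, :, None, None] -> (N, out_channel, 1, 1),
# which broadcasts over H and W when added to the (N, out_channel, H, W) map.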
# Downsampling: channel count unchanged, spatial size halved.
class Down(nn.Module):
    def __init__(self, channel):
        super().__init__()
        self.conv = nn.Conv2d(channel, channel, kernel_size=4, padding=1, stride=2)
        self.silu = nn.SiLU()
        self.bn = nn.BatchNorm2d(channel)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.silu(x)
        return x
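# Output-size check for Down (standard conv arithmetic): with kernel_size=4,
# stride=2, padding=1, H_out = floor((H + 2*1 - 4) / 2) + 1 = H / 2 for even H,
# e.g. 16 -> 8, 8 -> 4, 4 -> 2, matching the sizes used in MnistClassifier.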
# Upsampling: channel count divided by four, spatial size doubled.
class Up(nn.Module):
    def __init__(self, channel):
        super().__init__()
        self.conv = nn.Sequential(
            # Upsample first, then convolve.
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(channel, channel // 4, kernel_size=3, padding=1),  # integer division required here
            nn.BatchNorm2d(channel // 4),
            nn.SiLU()
        )

    def forward(self, x):
        return self.conv(x)
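# Why channel // 4: Up(1024) yields 256 channels; concatenating the 256-channel
# skip connection restores 512, which is exactly what ResModule(512, 256) below
# expects. The same bookkeeping holds for Up(256) -> 64+64=128 and Up(64) -> 16+16=32.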
class MnistClassifier(nn.Module):
    def __init__(self, time_dim, drop_p=0.5, type_act="silu", time_scale=1.0):
        super().__init__()
        # Keep the architecture as simple as possible.
        # First stage: channels 1 -> 4.
        self.d1_1 = ResModule(1, 4, time_dim, drop_p, type_act)
        self.d1_2 = ResModule(4, 4, time_dim, drop_p, type_act)
        # kernel_size=2 / padding=2 / stride=2 maps 28x28 -> 16x16, so the
        # later halvings land on even sizes (8, 4, 2).
        self.d1_3 = nn.Sequential(
            nn.Conv2d(4, 4, kernel_size=2, padding=2, stride=2),
            nn.BatchNorm2d(4),
            nn.SiLU()
        )
        # Second stage: channels 4 -> 16.
        self.d2_1 = ResModule(4, 16, time_dim, drop_p, type_act)
        self.d2_2 = ResModule(16, 16, time_dim, drop_p, type_act)
        self.d2_3 = Down(16)
        # Third stage: channels 16 -> 64.
        self.d3_1 = ResModule(16, 64, time_dim, drop_p, type_act)
        self.d3_2 = ResModule(64, 64, time_dim, drop_p, type_act)
        self.d3_3 = Down(64)
        # Fourth stage: channels 64 -> 256.
        self.d4_1 = ResModule(64, 256, time_dim, drop_p, type_act)
        self.d4_2 = ResModule(256, 256, time_dim, drop_p, type_act)
        self.d4_3 = Down(256)
        # Bottleneck.
        self.m1 = ResModule(256, 1024, time_dim, drop_p, type_act)
        self.m2 = ResModule(1024, 1024, time_dim, drop_p, type_act)
        # Decoder: each Up is followed by a skip concatenation (see forward).
        self.u1_1 = Up(1024)
        self.u1_2 = ResModule(512, 256, time_dim, drop_p, type_act)
        self.u1_3 = ResModule(256, 256, time_dim, drop_p, type_act)
        self.u2_1 = Up(256)
        self.u2_2 = ResModule(128, 64, time_dim, drop_p, type_act)
        self.u2_3 = ResModule(64, 64, time_dim, drop_p, type_act)
        self.u3_1 = Up(64)
        self.u3_2 = ResModule(32, 16, time_dim, drop_p, type_act)
        self.u3_3 = ResModule(16, 16, time_dim, drop_p, type_act)
        self.u4_1 = ResModule(16, 10, time_dim, drop_p, type_act)
        self.u4_2 = ResModule(10, 10, time_dim, drop_p, type_act)
        # Use GAP (global average pooling) instead of a fully-connected head.
        self.gap = nn.AdaptiveAvgPool2d(1)
        # Timestep encoding.
        self.embed_time = PositionalEmbedding(time_dim, scale=time_scale)
    def forward(self, x, t):
        # Turn t into an embedding.
        time_embedding = self.embed_time(t)
        x = self.d1_1(x, time_embedding)
        out1 = self.d1_2(x, time_embedding)  # note: out1 has no matching skip connection
        x = self.d1_3(out1)
        x = self.d2_1(x, time_embedding)
        out2 = self.d2_2(x, time_embedding)
        x = self.d2_3(out2)
        x = self.d3_1(x, time_embedding)
        out3 = self.d3_2(x, time_embedding)
        x = self.d3_3(out3)
        x = self.d4_1(x, time_embedding)
        out4 = self.d4_2(x, time_embedding)
        x = self.d4_3(out4)
        x = self.m1(x, time_embedding)
        x = self.m2(x, time_embedding)
        x = self.u1_1(x)
        x = torch.cat((x, out4), dim=1)
        x = self.u1_2(x, time_embedding)
        x = self.u1_3(x, time_embedding)
        x = self.u2_1(x)
        x = torch.cat((x, out3), dim=1)
        x = self.u2_2(x, time_embedding)
        x = self.u2_3(x, time_embedding)
        x = self.u3_1(x)
        x = torch.cat((x, out2), dim=1)
        x = self.u3_2(x, time_embedding)
        x = self.u3_3(x, time_embedding)
        x = self.u4_1(x, time_embedding)
        x = self.u4_2(x, time_embedding)
        x = self.gap(x)  # (N, 10, 1, 1)
        x = x.flatten(1)  # (N, 10); flatten rather than squeeze() so N=1 keeps its batch dim
        return x
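# Spatial sizes through the network for a 28x28 MNIST input:
# 28 -> 16 (d1_3) -> 8 (d2_3) -> 4 (d3_3) -> 2 (d4_3) -> 4 -> 8 -> 16 (the Ups),
# then GAP reduces 16x16 to 1x1, giving (N, 10) logits.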
if __name__ == "__main__":
    m = MnistClassifier(128)
    x = torch.randn((5, 1, 28, 28))
    t = torch.tensor([1, 2, 3, 4, 5])
    ret = m(x, t)
    print(ret.shape)  # torch.Size([5, 10])
    loss_fn = nn.CrossEntropyLoss()
    label = torch.tensor([1, 2, 3, 4, 5])
    loss = loss_fn(ret, label)
    print(loss.item())
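    # Classifier-guidance sketch (an assumed usage, based on the note about
    # plugging this classifier into DDIM; not code from this file): the
    # gradient of log p(y | x_t, t) w.r.t. x_t steers each denoising step
    # toward label y.
    x_in = x.detach().requires_grad_(True)
    log_probs = torch.log_softmax(m(x_in, t), dim=-1)
    selected = log_probs[torch.arange(len(label)), label].sum()
    grad = torch.autograd.grad(selected, x_in)[0]
    print(grad.shape)  # torch.Size([5, 1, 28, 28]); scaled and added to the DDIM mean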