"""
构建mnist的分类器
"""
import torch
from torch import nn
import math
def getAct(type_act):
    if type_act == "relu":
        return nn.ReLU
    elif type_act == "sigmoid":
        return nn.Sigmoid
    elif type_act == "silu":
        return nn.SiLU
    else:
        # Fail loudly on an unknown activation name instead of print + assert.
        raise ValueError(f"unknown activation type: {type_act}")
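# Note (my annotation): getAct returns the activation *class* (e.g. nn.SiLU), not an
# instance; callers instantiate it themselves, e.g. act = getAct("silu")().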
# Turn the input timestep t into an embedding.
# Done via a hard-coded sinusoidal scheme; not the focus here, so no details.
class PositionalEmbedding(nn.Module):
    __doc__ = r"""Computes a positional embedding of timesteps.
    Input:
        x: tensor of shape (N)
    Output:
        tensor of shape (N, dim)
    Args:
        dim (int): embedding dimension
        scale (float): linear scale to be applied to timesteps. Default: 1.0
    """
    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale
    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
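# What the embedding computes (my annotation, not part of the original code):
# with half = dim // 2 and frequencies w_j = 10000 ** (-j / half) for j = 0..half-1,
#     emb[i, j]        = sin(scale * t_i * w_j)
#     emb[i, half + j] = cos(scale * t_i * w_j)
# e.g. PositionalEmbedding(128)(torch.arange(5)) yields a (5, 128) tensor.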
# Residual block (with skip connection)
class ResModule(nn.Module):
    def __init__(self, in_channel, out_channel, time_dim, drop_p=0.5, type_act="relu"):
        super(ResModule, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.time_dim = time_dim
        self.act = getAct(type_act)
        self.drop_p = drop_p
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channel),
            self.act()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channel),
            self.act()
        )
        # Project the timestep embedding to out_channel so it can be added to the feature map.
        self.time_handle = nn.Sequential(
            nn.Linear(self.time_dim, self.out_channel),
            nn.Dropout(self.drop_p),
            self.act()
        )
        self.res = None  # The shortcut must match conv2's output shape so the residual add works.
        if self.in_channel == self.out_channel:
            self.res = nn.Identity()
        else:
            self.res = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channel),
                self.act()
            )
    def forward(self, x, time_embedding):
        out = self.conv1(x)
        time_embedding = self.time_handle(time_embedding)
        time_embedding = time_embedding[:, :, None, None]
        out = out + time_embedding
        out = self.conv2(out)
        x = self.res(x)
        out = out + x  # Don't use out += x here: that would be an in-place operation.
        return out
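# Data flow through ResModule (my annotation): x is (N, in_channel, H, W) and
# time_embedding is (N, time_dim). The embedding is projected to (N, out_channel),
# reshaped to (N, out_channel, 1, 1), and broadcast-added to every spatial position
# of conv1's output before conv2, so the output is (N, out_channel, H, W).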
# Downsampling
class Down(nn.Module):
    def __init__(self, channel):
        super(Down, self).__init__()
        # Keep the channel count but halve the spatial size.
        self.conv = nn.Conv2d(channel, channel, kernel_size=4, padding=1, stride=2)
        self.silu = nn.SiLU()
        self.bn = nn.BatchNorm2d(channel)
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.silu(x)
        return x
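# Size check for Down (my annotation): with kernel_size=4, padding=1, stride=2 the
# output size is floor((H + 2*1 - 4) / 2) + 1, which equals H // 2 for even H
# (e.g. 16 -> 8, 8 -> 4, 4 -> 2), so "halve the spatial size" holds.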
# Divide the channel count by four and double the spatial size.
class Up(nn.Module):
    def __init__(self, channel):
        super(Up, self).__init__()
        self.conv = nn.Sequential(
            # Upsample first, then reduce the channel count.
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(channel, channel // 4, kernel_size=3, padding=1),  # integer division is required here
            nn.BatchNorm2d(channel // 4),
            nn.SiLU()
        )
    def forward(self, x):
        return self.conv(x)
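# Why channel // 4 (my reading of the decoder below): each Up output is concatenated
# with a skip connection of the same width, so after the concat the channel count is
# channel // 2; e.g. Up(1024) gives 256 channels, and the cat with the 256-channel
# skip restores 512 before the next ResModule.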
class MnistClassifier(nn.Module):
    def __init__(self, time_dim, drop_p=0.5, type_act="silu", time_scale=1.0):
        super(MnistClassifier, self).__init__()
        # Keep this as simple as possible; almost no architecture knobs are exposed.
        # First conv stage: channels 1 -> 4
        self.d1_1 = ResModule(1, 4, time_dim, drop_p, type_act)
        self.d1_2 = ResModule(4, 4, time_dim, drop_p, type_act)
        self.d1_3 = nn.Sequential(
            # Note: kernel_size=2 with padding=2 maps 28x28 to 16x16 rather than exactly halving.
            nn.Conv2d(4, 4, kernel_size=2, padding=2, stride=2),
            nn.BatchNorm2d(4),
            nn.SiLU()
        )
        # Second conv stage: channels 4 -> 16
        self.d2_1 = ResModule(4, 16, time_dim, drop_p, type_act)
        self.d2_2 = ResModule(16, 16, time_dim, drop_p, type_act)
        self.d2_3 = Down(16)
        # Third conv stage: channels 16 -> 64
        self.d3_1 = ResModule(16, 64, time_dim, drop_p, type_act)
        self.d3_2 = ResModule(64, 64, time_dim, drop_p, type_act)
        self.d3_3 = Down(64)
        # Fourth conv stage: channels 64 -> 256
        self.d4_1 = ResModule(64, 256, time_dim, drop_p, type_act)
        self.d4_2 = ResModule(256, 256, time_dim, drop_p, type_act)
        self.d4_3 = Down(256)
        # Bottleneck: channels 256 -> 1024
        self.m1 = ResModule(256, 1024, time_dim, drop_p, type_act)
        self.m2 = ResModule(1024, 1024, time_dim, drop_p, type_act)
        # Decoder: each Up is followed by a concat with the matching skip connection.
        self.u1_1 = Up(1024)
        self.u1_2 = ResModule(512, 256, time_dim, drop_p, type_act)
        self.u1_3 = ResModule(256, 256, time_dim, drop_p, type_act)
        self.u2_1 = Up(256)
        self.u2_2 = ResModule(128, 64, time_dim, drop_p, type_act)
        self.u2_3 = ResModule(64, 64, time_dim, drop_p, type_act)
        self.u3_1 = Up(64)
        self.u3_2 = ResModule(32, 16, time_dim, drop_p, type_act)
        self.u3_3 = ResModule(16, 16, time_dim, drop_p, type_act)
        self.u4_1 = ResModule(16, 10, time_dim, drop_p, type_act)
        self.u4_2 = ResModule(10, 10, time_dim, drop_p, type_act)
        # Use global average pooling (GAP) instead of a fully connected head.
        self.gap = nn.AdaptiveAvgPool2d(1)
        # Timestep encoding
        self.embed_time = PositionalEmbedding(time_dim, scale=time_scale)
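    # Shape trace through forward (my annotation, assuming a standard (N, 1, 28, 28) MNIST input):
    #   encoder:  28x28/4ch -> 16x16/16ch -> 8x8/64ch -> 4x4/256ch -> 2x2/1024ch
    #   decoder:  2x2 -> 4x4 (cat out4) -> 8x8 (cat out3) -> 16x16 (cat out2) -> 16x16/10ch
    #   head:     AdaptiveAvgPool2d(1) + squeeze -> logits of shape (N, 10)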
    def forward(self, x, t):
        # Turn t into an embedding.
        time_embedding = self.embed_time(t)
        x = self.d1_1(x, time_embedding)
        out1 = self.d1_2(x, time_embedding)  # unlike out2/out3/out4, out1 is never concatenated back in
        x = self.d1_3(out1)
        x = self.d2_1(x, time_embedding)
        out2 = self.d2_2(x, time_embedding)
        x = self.d2_3(out2)
        x = self.d3_1(x, time_embedding)
        out3 = self.d3_2(x, time_embedding)
        x = self.d3_3(out3)
        x = self.d4_1(x, time_embedding)
        out4 = self.d4_2(x, time_embedding)
        x = self.d4_3(out4)
        x = self.m1(x, time_embedding)
        x = self.m2(x, time_embedding)
        x = self.u1_1(x)
        x = torch.cat((x, out4), dim=1)
        x = self.u1_2(x, time_embedding)
        x = self.u1_3(x, time_embedding)
        x = self.u2_1(x)
        x = torch.cat((x, out3), dim=1)
        x = self.u2_2(x, time_embedding)
        x = self.u2_3(x, time_embedding)
        x = self.u3_1(x)
        x = torch.cat((x, out2), dim=1)
        x = self.u3_2(x, time_embedding)
        x = self.u3_3(x, time_embedding)
        x = self.u4_1(x, time_embedding)
        x = self.u4_2(x, time_embedding)
        x = self.gap(x)
        # Only squeeze the spatial dims so a batch of size 1 keeps its batch dimension.
        x = x.squeeze(-1).squeeze(-1)
        return x
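# A minimal training sketch (my addition, not part of the original code). Assumptions:
# MNIST is loaded via torchvision, timesteps are drawn uniformly from [0, n_timesteps),
# and images are fed in as-is; how x and t are actually paired (e.g. with noised images
# for classifier guidance) depends on the rest of the project.
def train_one_epoch(model, loader, optimizer, device, n_timesteps=1000):
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        # Random timestep per sample (assumed sampling scheme).
        t = torch.randint(0, n_timesteps, (images.size(0),), device=device)
        logits = model(images, t)
        loss = loss_fn(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
# Example wiring (hypothetical):
#   import torchvision
#   ds = torchvision.datasets.MNIST("./data", train=True, download=True,
#                                   transform=torchvision.transforms.ToTensor())
#   loader = torch.utils.data.DataLoader(ds, batch_size=128, shuffle=True)
#   model = MnistClassifier(128)
#   train_one_epoch(model, loader, torch.optim.Adam(model.parameters(), lr=1e-3), "cpu")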
if __name__ == "__main__":
m = MnistClassifier(128)
x = torch.randn((5, 1, 28, 28))
t = torch.tensor([1, 2, 3, 4, 5])
# print(ret.shape)
# ret = m(x, t)
# print(ret.shape)
# loss_fn = nn.CrossEntropyLoss()
# label = torch.tensor([1, 2, 3, 4, 5])
# loss = loss_fn(ret, label)
# print(loss.item())