加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
UNet.py 2.78 KB
一键复制 编辑 原始数据 按行查看 历史
CVHuber 提交于 2020-09-29 16:03 . init upload
import torch.nn as nn
import torch
import torch.nn.functional as F
class conv_block(nn.Module):
def __init__(self, in_ch, out_ch):
super(conv_block, self).__init__()
# 定义一个序列操作 卷积->BN->Relu->卷积->BN->Relu
self.conv = nn.Sequential(
# 参数分别为:输入通道数,输出通道数,卷积核大小以及填充大小
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, x):
x = self.conv(x)
return x
class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
# 下采样步骤包含了最大池化和一个卷积块
self.max_pool_conv = nn.Sequential(
nn.MaxPool2d(2),
conv_block(in_ch, out_ch)
)
def forward(self, x):
x = self.max_pool_conv(x)
return x
class up(nn.Module):
def __init__(self, in_ch, out_ch):
super(up, self).__init__()
# 上采样模块包括了转置卷积和一个卷积块
self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
self.conv = conv_block(in_ch, out_ch)
def forward(self, x1, x2):
x1 = self.up(x1)
# U-Net中的跳跃连接,将两个特征图拼接起来
x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x
class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
# 输出块,二分类问题输出一个特征图即可,此时out_ch应该为1
self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)
def forward(self, x):
x = self.conv(x)
return x
class UNet(nn.Module):
def __init__(self, n_channels=3, n_classes=1):
super(UNet, self).__init__()
self.inc = conv_block(n_channels, 64)
self.down1 = down(64, 128)
self.down2 = down(128, 256)
self.down3 = down(256, 512)
self.down4 = down(512, 512)
self.up1 = up(1024, 256)
self.up2 = up(512, 128)
self.up3 = up(256, 64)
self.up4 = up(128, 64)
self.outc = outconv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
# 使用Sigmoid进行归一化
return F.sigmoid(x)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化