# [Web-page residue, not part of the module] Code fetch complete; the page will refresh automatically.
import torch
import torch.nn as nn
import torch.nn.functional as F
class CBDNet(nn.Module):
    """Chain a noise-level estimator (``FCN``) with a ``UNet`` stage.

    The input is first passed through ``self.fcn``; its output is
    concatenated to the input along dimension 1 and fed to ``self.unet``.
    The input is then added back to the UNet output.
    """

    def __init__(self):
        super().__init__()
        self.fcn = FCN()
        self.unet = UNet()

    def forward(self, x):
        """Return ``(noise_level, out)`` for the input batch ``x``.

        ``noise_level`` is the raw output of ``self.fcn``; ``out`` is the
        UNet output plus ``x`` (a residual connection around the UNet).
        """
        estimated_noise = self.fcn(x)
        # Stack the image and its estimated noise map on the channel axis.
        unet_input = torch.cat((x, estimated_noise), dim=1)
        residual = self.unet(unet_input)
        return estimated_noise, residual + x
class FCN(nn.Module):
    """Plain stack of 3x3 convolutions + ReLU mapping 3 channels to 3 channels.

    Layout: ``inc`` (3 -> 32), ``conv`` (32 -> 32) applied three times, then
    ``outc`` (32 -> 3). Because the very same ``self.conv`` module is called
    three times in ``forward``, those three steps share one set of weights.
    Every convolution uses ``padding=1``, so spatial size is preserved, and
    the final ReLU makes the output non-negative.
    """

    def __init__(self):
        super().__init__()
        self.inc = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.outc = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run ``inc``, three passes of the shared ``conv``, then ``outc``."""
        features = self.inc(x)
        # Same module each time: three applications of one weight set.
        for _ in range(3):
            features = self.conv(features)
        return self.outc(features)
class UNet(nn.Module):
    """Three-level encoder/decoder built from ``single_conv`` blocks.

    Encoder: ``inc`` (6 -> 64), average-pool, ``conv1`` (64 -> 128),
    average-pool, ``conv2`` (128 -> 256). Decoder: ``up1`` merges the
    256-channel features with the ``conv1`` output, ``conv3`` refines at 128
    channels, ``up2`` merges with the ``inc`` output, ``conv4`` refines at 64
    channels, and ``outc`` projects to 3 channels.
    """

    def __init__(self):
        super().__init__()

        # --- encoder -------------------------------------------------------
        self.inc = nn.Sequential(
            single_conv(6, 64),
            single_conv(64, 64),
        )

        self.down1 = nn.AvgPool2d(2)
        self.conv1 = nn.Sequential(
            single_conv(64, 128),
            *(single_conv(128, 128) for _ in range(2)),
        )

        self.down2 = nn.AvgPool2d(2)
        self.conv2 = nn.Sequential(
            single_conv(128, 256),
            *(single_conv(256, 256) for _ in range(5)),
        )

        # --- decoder -------------------------------------------------------
        self.up1 = up(256)
        self.conv3 = nn.Sequential(
            *(single_conv(128, 128) for _ in range(3)),
        )

        self.up2 = up(128)
        self.conv4 = nn.Sequential(
            *(single_conv(64, 64) for _ in range(2)),
        )

        self.outc = outconv(64, 3)

    def forward(self, x):
        """Encode ``x`` through two pooling stages, decode, and project to 3 channels."""
        level0 = self.inc(x)
        level1 = self.conv1(self.down1(level0))
        bottom = self.conv2(self.down2(level1))

        # Each `up` call receives the deeper features first and the
        # same-resolution encoder features second.
        merged1 = self.conv3(self.up1(bottom, level1))
        merged0 = self.conv4(self.up2(merged1, level0))

        return self.outc(merged0)
class single_conv(nn.Module):
    """A 3x3 convolution (``padding=1``) followed by an in-place ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution + ReLU pair to ``x`` and return the result."""
        return self.conv(x)
class up(nn.Module):
    """Upsample by 2 with a transposed convolution and add a skip tensor.

    The transposed convolution halves the channel count
    (``in_ch -> in_ch // 2``) with kernel size 2 and stride 2.

    Args:
        in_ch: number of channels of the tensor to be upsampled.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the spatial size of ``x2``, return ``x2 + x1``.

        Dimensions 2 and 3 are used as height and width, so both tensors are
        expected to be 4-D (presumably NCHW -- confirm against callers).
        """
        upsampled = self.up(x1)

        # Size gap between the skip tensor and the upsampled tensor.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)

        # F.pad takes (left, right, top, bottom) for the last two dims; any
        # odd remainder goes to the right / bottom side.
        left = gap_w // 2
        top = gap_h // 2
        padding = (left, gap_w - left, top, gap_h - top)

        return x2 + F.pad(upsampled, padding)
class outconv(nn.Module):
    """Final 1x1 convolution that maps ``in_ch`` channels to ``out_ch``.

    No activation is applied after the convolution.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        """Return the 1x1 convolution of ``x``."""
        return self.conv(x)
class fixed_loss(nn.Module):
    """Combined training loss: reconstruction + asymmetric noise loss + TV.

    ``loss = MSE(out_image, gt_image)
             + if_asym * 0.5 * asymmetric_loss(est_noise, gt_noise)
             + 0.05 * total_variation(est_noise)``
    """

    def __init__(self):
        super().__init__()

    def forward(self, out_image, gt_image, est_noise, gt_noise, if_asym):
        """Compute the combined loss.

        Args:
            out_image: network output image.
            gt_image: reference image, same shape as ``out_image``.
            est_noise: estimated noise map; indexed as a 4-D tensor with
                height on dim 2 and width on dim 3.
            gt_noise: reference noise map, broadcastable with ``est_noise``.
            if_asym: multiplier applied to the asymmetric term (e.g. 0 to
                disable it, 1 to enable it).

        Returns:
            The weighted sum of the three terms.
        """
        # Reconstruction term: mean squared error between output and target.
        recon_loss = torch.mean(torch.pow(out_image - gt_image, 2))

        # Asymmetric term: weight |0.3 - 1[est < gt]| gives 0.7 where the
        # noise is under-estimated and 0.3 where it is over-estimated.
        # The indicator must be 0/1; using the raw magnitude
        # relu(gt - est) instead made the weight vanish when the
        # under-estimation equalled 0.3 and penalised under-estimation less
        # than over-estimation.
        under_estimated = (gt_noise > est_noise).to(est_noise.dtype)
        asym_weight = torch.abs(0.3 - under_estimated)
        asym_loss = torch.mean(
            torch.mul(asym_weight, torch.pow(est_noise - gt_noise, 2))
        )

        # Total-variation term on the estimated noise map: squared
        # differences between vertically / horizontally adjacent entries.
        diff_h = est_noise[:, :, 1:, :] - est_noise[:, :, :-1, :]
        diff_w = est_noise[:, :, :, 1:] - est_noise[:, :, :, :-1]
        # NOTE: the counts exclude the batch dimension while the sums run
        # over it, so this term grows with batch size (kept as in the
        # original). Clamp to 1 so a map with a single row or column yields
        # 0 instead of 0/0 = NaN.
        count_h = max(self._tensor_size(diff_h), 1)
        count_w = max(self._tensor_size(diff_w), 1)
        h_tv = torch.pow(diff_h, 2).sum()
        w_tv = torch.pow(diff_w, 2).sum()
        tvloss = h_tv / count_h + w_tv / count_w

        return recon_loss + if_asym * 0.5 * asym_loss + 0.05 * tvloss

    def _tensor_size(self, t):
        """Return the number of elements per batch item (product of dims 1..3)."""
        return t.size()[1] * t.size()[2] * t.size()[3]
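# [Web-page residue, not part of the module] This location may contain content unsuitable for display, so the page does not show it. You can review and modify it yourself using the relevant editing functions.
# If you confirm that the content involves no inappropriate language, pure advertising or traffic diversion, violence, vulgar or pornographic material, infringement, piracy, falsehood, or worthless content, and nothing that violates applicable national laws and regulations, you can click Submit to appeal and we will handle it as soon as possible.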