加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
loss.py 2.13 KB
一键复制 编辑 原始数据 按行查看 历史
Alex Damian 提交于 2020-05-20 14:05 . Initial refactor
import torch
from bicubic import BicubicDownSample
class LossBuilder(torch.nn.Module):
    """Weighted sum of image-space and latent-space losses configured by a string.

    ``loss_str`` is a ``+``-separated list of ``weight*NAME`` terms, e.g.
    ``"100*L2+0.05*GEOCROSS"``. Whitespace around terms, weights and names is
    ignored, empty terms are skipped, and a term without ``*`` (e.g. ``"L2"``)
    gets weight 1. Recognised names are ``L2``, ``L1`` and ``GEOCROSS``.

    Args:
        ref_im: Reference image tensor indexed as ``(N, C, H, W)``-style with
            ``H == W`` (``shape[2] == shape[3]``); the side length must evenly
            divide 1024.
        loss_str: Loss specification (see above).
        eps: Lower clamp applied to each per-sample L1/L2 loss before summing.

    Raises:
        ValueError: If ``ref_im`` is not square in dims 2/3, or its side length
            does not evenly divide 1024.
    """

    def __init__(self, ref_im, loss_str, eps):
        super().__init__()
        if ref_im.shape[2] != ref_im.shape[3]:
            raise ValueError(
                f"ref_im must be square in dims 2 and 3, got "
                f"{ref_im.shape[2]}x{ref_im.shape[3]}"
            )
        im_size = ref_im.shape[2]
        # Explicit check instead of `assert` so validation survives `python -O`.
        # NOTE(review): presumably generated images are 1024x1024 and get
        # downsampled by an integer factor to ref_im's size -- confirm.
        if im_size <= 0 or 1024 % im_size != 0:
            raise ValueError(f"ref_im size {im_size} must evenly divide 1024")
        factor = 1024 // im_size
        self.D = BicubicDownSample(factor=factor)
        # NOTE: stored as a plain attribute (not a registered buffer), so
        # Module.to(device) does not move it.
        self.ref_im = ref_im
        # List of [weight_str, loss_name] pairs; weights are parsed in forward().
        self.parsed_loss = self._parse_loss_str(loss_str)
        self.eps = eps

    @staticmethod
    def _parse_loss_str(loss_str):
        """Split ``"w1*NAME1+w2*NAME2"`` into ``[[w1, NAME1], [w2, NAME2]]``.

        Weights stay strings. Surrounding whitespace is stripped, empty terms
        are skipped, and a term with no ``*`` gets weight ``'1'``.
        """
        parsed = []
        for term in loss_str.split('+'):
            if not term.strip():
                continue  # tolerate stray/trailing '+'
            weight, _, loss_type = term.rpartition('*')
            parsed.append([weight.strip() or '1', loss_type.strip()])
        return parsed

    # Takes a list of tensors (or a single tensor), flattens each, and
    # concatenates them into one 1-D vector. Not called within this class.
    def flatcat(self, l):
        l = l if isinstance(l, list) else [l]
        return torch.cat([x.flatten() for x in l], dim=0)

    def _loss_l2(self, gen_im_lr, ref_im, **kwargs):
        """Per-sample mean squared error over dims 1-3, clamped at eps, summed."""
        return (gen_im_lr - ref_im).pow(2).mean((1, 2, 3)).clamp(min=self.eps).sum()

    def _loss_l1(self, gen_im_lr, ref_im, **kwargs):
        """10x the per-sample mean absolute error over dims 1-3, clamped at eps, summed."""
        return 10 * (gen_im_lr - ref_im).abs().mean((1, 2, 3)).clamp(min=self.eps).sum()

    def _loss_geocross(self, latent, **kwargs):
        """Penalise pairwise angles between the 18 latent vectors of each sample.

        Returns the integer 0 when ``latent.shape[1] == 1`` (a single shared
        vector has no pairs). Otherwise the latent is viewed as 18 vectors of
        length 512 per sample.
        """
        if latent.shape[1] == 1:
            return 0
        X = latent.view(-1, 1, 18, 512)
        Y = latent.view(-1, 18, 1, 512)
        # 1e-9 keeps sqrt differentiable at zero distance.
        A = ((X - Y).pow(2).sum(-1) + 1e-9).sqrt()
        B = ((X + Y).pow(2).sum(-1) + 1e-9).sqrt()
        # For equal-norm vectors 2*atan2(|x-y|, |x+y|) equals the angle between them.
        D = 2 * torch.atan2(A, B)
        # Squared angles scaled by 512 and divided by 8, averaged over all
        # 18x18 pairs (diagonal included), then summed over the batch.
        D = ((D.pow(2) * 512).mean((1, 2)) / 8.).sum()
        return D

    def forward(self, latent, gen_im):
        """Compute the weighted total loss.

        Args:
            latent: Latent tensor passed to the GEOCROSS term.
            gen_im: Generated image; downsampled with ``self.D`` before the
                L1/L2 comparison against ``self.ref_im``.

        Returns:
            ``(loss, losses)`` where ``loss`` is the weighted sum and ``losses``
            maps each loss name to its unweighted value (a repeated name keeps
            the last value).

        Raises:
            KeyError: If ``loss_str`` named an unknown loss type.
            ValueError: If a weight string is not a valid float.
        """
        var_dict = {
            'latent': latent,
            'gen_im_lr': self.D(gen_im),
            'ref_im': self.ref_im,
        }
        loss_fun_dict = {
            'L2': self._loss_l2,
            'L1': self._loss_l1,
            'GEOCROSS': self._loss_geocross,
        }
        loss = 0
        losses = {}
        for weight, loss_type in self.parsed_loss:
            try:
                loss_fun = loss_fun_dict[loss_type]
            except KeyError:
                raise KeyError(
                    f"Unknown loss type {loss_type!r}; "
                    f"expected one of {sorted(loss_fun_dict)}"
                ) from None
            tmp_loss = loss_fun(**var_dict)
            losses[loss_type] = tmp_loss
            loss += float(weight) * tmp_loss
        return loss, losses
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化