加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
trainer.py 19.22 KB
一键复制 编辑 原始数据 按行查看 历史
chrisjuniorli 提交于 2023-07-21 07:49 . update
# ------------------------------------------------------------------------
# Modified from MGMatting (https://github.com/yucornetto/MGMatting)
# ------------------------------------------------------------------------
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as nn_utils
import torch.backends.cudnn as cudnn
from torch.nn import SyncBatchNorm
import torch.optim.lr_scheduler as lr_scheduler
from torch.nn.parallel import DistributedDataParallel
import utils
from utils import CONFIG
import networks
import wandb
import cv2
class Trainer(object):
    """Training driver for the generator built by ``networks.get_generator_m2m``.

    Construction stores the dataloaders and loggers, creates the fixed filter
    tensors on the GPU, builds the model / optimizer / cosine LR scheduler and
    optionally restores a checkpoint or loads pretrained weights.
    ``train`` runs the optimisation loop; the static methods implement the
    individual loss and metric terms.
    """

    def __init__(self,
                 train_dataloader,
                 test_dataloader,
                 logger,
                 tb_logger):
        """Set up state, filters and the model.

        :param train_dataloader: iterable yielding batch dicts consumed by ``train``
        :param test_dataloader: stored on the instance; not used by the code in this class
        :param logger: logger exposing ``info`` / ``debug`` / ``error``
        :param tb_logger: object exposing ``scalar_summary(tag, value, step)``
        """
        cudnn.benchmark = True

        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.logger = logger
        self.tb_logger = tb_logger

        self.model_config = CONFIG.model
        self.train_config = CONFIG.train
        self.log_config = CONFIG.log

        # Per-term training losses; an entry stays None until its term is computed.
        self.loss_dict = {'rec_os8': None,
                          'comp_os8': None,
                          'rec_os1': None,
                          'comp_os1': None,
                          'smooth_l1': None,
                          'grad': None,
                          'gabor': None,
                          'lap_os8': None,
                          'lap_os1': None,
                          'rec_os4': None,
                          'comp_os4': None,
                          'lap_os4': None, }
        self.test_loss_dict = {'rec': None,
                               'smooth_l1': None,
                               'mse': None,
                               'sad': None,
                               'grad': None,
                               'gabor': None}

        self.grad_filter = torch.tensor(utils.get_gradfilter()).cuda()
        self.gabor_filter = torch.tensor(utils.get_gaborfilter(16)).cuda()

        # 5x5 kernel whose entries sum to 256; dividing by 256 normalises it to sum 1.
        self.gauss_filter = torch.tensor([[1., 4., 6., 4., 1.],
                                          [4., 16., 24., 16., 4.],
                                          [6., 24., 36., 24., 6.],
                                          [4., 16., 24., 16., 4.],
                                          [1., 4., 6., 4., 1.]]).cuda()
        self.gauss_filter /= 256.
        # repeat(1, 1, 1, 1) on a 2-D tensor yields shape (1, 1, 5, 5), the layout F.conv2d expects.
        self.gauss_filter = self.gauss_filter.repeat(1, 1, 1, 1)

        self.build_model()
        self.resume_step = None
        self.best_loss = 1e+8

        utils.print_network(self.G, CONFIG.version)

        if self.train_config.resume_checkpoint:
            self.logger.info('Resume checkpoint: {}'.format(self.train_config.resume_checkpoint))
            self.restore_model(self.train_config.resume_checkpoint)

        # Pretrained weights are only loaded for a fresh run (no resume checkpoint).
        if self.model_config.imagenet_pretrain and self.train_config.resume_checkpoint is None:
            self.logger.info('Load Imagenet Pretrained: {}'.format(self.model_config.imagenet_pretrain_path))
            if self.model_config.arch.encoder == "vgg_encoder":
                utils.load_VGG_pretrain(self.G, self.model_config.imagenet_pretrain_path)
            else:
                utils.load_imagenet_pretrain(self.G, self.model_config.imagenet_pretrain_path)

    def build_model(self):
        """Create the generator, its Adam optimizer, the parallel wrapper and the LR scheduler."""
        self.G = networks.get_generator_m2m(seg=self.model_config.arch.seg, m2m=self.model_config.arch.m2m)
        self.G.cuda()

        if CONFIG.dist:
            self.logger.info("Using pytorch synced BN")
            self.G = SyncBatchNorm.convert_sync_batchnorm(self.G)

        # The optimizer is created before the model is wrapped for parallel execution.
        self.G_optimizer = torch.optim.Adam(self.G.parameters(),
                                            lr=self.train_config.G_lr,
                                            betas=[self.train_config.beta1, self.train_config.beta2])

        if CONFIG.dist:
            # SyncBatchNorm only supports DistributedDataParallel with single GPU per process
            self.G = DistributedDataParallel(self.G,
                                             device_ids=[CONFIG.local_rank],
                                             output_device=CONFIG.local_rank,
                                             find_unused_parameters=True)
        else:
            self.G = nn.DataParallel(self.G)

        self.build_lr_scheduler()

    def build_lr_scheduler(self):
        """Build cosine learning rate scheduler spanning the post-warmup steps."""
        self.G_scheduler = lr_scheduler.CosineAnnealingLR(
            self.G_optimizer,
            T_max=self.train_config.total_step - self.train_config.warmup_step)

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.G_optimizer.zero_grad()

    def restore_model(self, resume_checkpoint):
        """
        Restore model weights and, unless ``reset_lr`` is set, optimizer and scheduler state.

        :param resume_checkpoint: File name of checkpoint (without the ``.pth`` suffix),
            resolved relative to ``log_config.checkpoint_path``
        :return: None
        """
        pth_path = os.path.join(self.log_config.checkpoint_path, '{}.pth'.format(resume_checkpoint))
        checkpoint = torch.load(pth_path, map_location=lambda storage, loc: storage.cuda(CONFIG.gpu))
        self.resume_step = checkpoint['iter']
        self.logger.info('Loading the trained models from step {}...'.format(self.resume_step))
        # NOTE(review): save_model() stores only self.G.module.m2m.state_dict(), while this
        # loads into the wrapped self.G with strict=True -- confirm the key layouts match
        # for checkpoints written by this trainer.
        self.G.load_state_dict(checkpoint['state_dict'], strict=True)

        if not self.train_config.reset_lr:
            if 'opt_state_dict' in checkpoint:
                try:
                    self.G_optimizer.load_state_dict(checkpoint['opt_state_dict'])
                except ValueError as ve:
                    self.logger.error("{}".format(ve))
            else:
                self.logger.info('No Optimizer State Loaded!!')

            if 'lr_state_dict' in checkpoint:
                try:
                    self.G_scheduler.load_state_dict(checkpoint['lr_state_dict'])
                except ValueError as ve:
                    self.logger.error("{}".format(ve))
        else:
            # Restart the cosine schedule over the steps that remain after the resume point.
            self.G_scheduler = lr_scheduler.CosineAnnealingLR(
                self.G_optimizer,
                T_max=self.train_config.total_step - self.resume_step - 1)

        if 'loss' in checkpoint:
            self.best_loss = checkpoint['loss']

    def train(self):
        """Run the training loop from the start (or resume) step through ``total_step``.

        Each step: fetch a batch, set the learning rate, run the generator,
        build per-scale pixel weights, accumulate the enabled loss terms,
        back-propagate (with optional gradient clipping), update parameters,
        then log and periodically save a checkpoint.

        :raises KeyError: if ``comp_weight > 0`` but the batch has no ``'fg'`` / ``'bg'`` entries
        """
        data_iter = iter(self.train_dataloader)

        if self.train_config.resume_checkpoint:
            start = self.resume_step + 1
        else:
            start = 0

        moving_max_grad = 0
        moving_grad_moment = 0.999
        max_grad = 0

        for step in range(start, self.train_config.total_step + 1):
            # Restart the dataloader when it is exhausted. Only StopIteration is
            # caught so genuine loader errors and KeyboardInterrupt still surface.
            try:
                image_dict = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_dataloader)
                image_dict = next(data_iter)

            image = image_dict['image'].cuda()
            alpha = image_dict['alpha'].cuda()
            trimap = image_dict['trimap'].cuda()
            bbox = image_dict['boxes'].cuda()

            # train() of DistributedDataParallel has no return
            self.G.train()

            log_info = ""
            loss = 0

            # ===== Update Learning Rate =====
            if step < self.train_config.warmup_step and self.train_config.resume_checkpoint is None:
                cur_G_lr = utils.warmup_lr(self.train_config.G_lr, step + 1, self.train_config.warmup_step)
                utils.update_lr(cur_G_lr, self.G_optimizer)
            else:
                self.G_scheduler.step()
                # get_last_lr() returns the rate the scheduler just applied; calling
                # get_lr() outside step() is deprecated and reports a different value.
                cur_G_lr = self.G_scheduler.get_last_lr()[0]

            # ===== Forward G =====
            pred = self.G(image, bbox)
            alpha_pred_os1 = pred['alpha_os1']
            alpha_pred_os4 = pred['alpha_os4']
            alpha_pred_os8 = pred['alpha_os8']
            mask = pred['mask']

            # The os8 output is always supervised on every pixel (weight forced to 1).
            weight_os8 = utils.get_unknown_tensor(mask)
            weight_os8[...] = 1

            if step < self.train_config.warmup_step:
                # Warmup: all scales are supervised on every pixel.
                weight_os4 = utils.get_unknown_tensor(mask)
                weight_os1 = utils.get_unknown_tensor(mask)
                weight_os4[...] = 1
                weight_os1[...] = 1
            elif step < self.train_config.warmup_step * 3:
                # Coin flip between weights derived from the predicted mask and from the trimap.
                if random.randint(0, 1) == 0:
                    weight_os4 = utils.get_unknown_tensor(mask)
                    weight_os1 = utils.get_unknown_tensor(mask)
                else:
                    weight_os4 = utils.get_unknown_tensor(trimap)
                    weight_os1 = utils.get_unknown_tensor(trimap)
            else:
                # Coin flip between trimap weights and weights derived from the coarser predictions.
                if random.randint(0, 1) == 0:
                    weight_os4 = utils.get_unknown_tensor(trimap)
                    weight_os1 = utils.get_unknown_tensor(trimap)
                else:
                    weight_os4 = utils.get_unknown_tensor_from_pred(
                        alpha_pred_os8, rand_width=CONFIG.model.self_refine_width1, train_mode=True)
                    weight_os1 = utils.get_unknown_tensor_from_pred(
                        alpha_pred_os4, rand_width=CONFIG.model.self_refine_width2, train_mode=True)

            # Each enabled term is split 2/5, 1/5, 1/5 over os1, os4, os8.
            if self.train_config.rec_weight > 0:
                self.loss_dict['rec_os1'] = self.regression_loss(alpha_pred_os1, alpha, loss_type='l1', weight=weight_os1) * 2 / 5.0 * self.train_config.rec_weight
                self.loss_dict['rec_os4'] = self.regression_loss(alpha_pred_os4, alpha, loss_type='l1', weight=weight_os4) * 1 / 5.0 * self.train_config.rec_weight
                self.loss_dict['rec_os8'] = self.regression_loss(alpha_pred_os8, alpha, loss_type='l1', weight=weight_os8) * 1 / 5.0 * self.train_config.rec_weight

            if self.train_config.comp_weight > 0:
                # The composition loss needs foreground/background tensors that were
                # previously referenced without ever being defined (NameError).
                # NOTE(review): assumes the batch carries them under 'fg' / 'bg' as in
                # upstream MGMatting -- confirm against the dataloader.
                try:
                    fg_norm = image_dict['fg'].cuda()
                    bg_norm = image_dict['bg'].cuda()
                except KeyError as err:
                    raise KeyError(
                        "comp_weight > 0 requires 'fg' and 'bg' entries in the training batch"
                    ) from err
                self.loss_dict['comp_os1'] = self.composition_loss(alpha_pred_os1, fg_norm, bg_norm, image, weight=weight_os1) * 2 / 5.0 * self.train_config.comp_weight
                self.loss_dict['comp_os4'] = self.composition_loss(alpha_pred_os4, fg_norm, bg_norm, image, weight=weight_os4) * 1 / 5.0 * self.train_config.comp_weight
                self.loss_dict['comp_os8'] = self.composition_loss(alpha_pred_os8, fg_norm, bg_norm, image, weight=weight_os8) * 1 / 5.0 * self.train_config.comp_weight

            if self.train_config.lap_weight > 0:
                self.loss_dict['lap_os1'] = self.lap_loss(logit=alpha_pred_os1, target=alpha, gauss_filter=self.gauss_filter, loss_type='l1', weight=weight_os1) * 2 / 5.0 * self.train_config.lap_weight
                self.loss_dict['lap_os4'] = self.lap_loss(logit=alpha_pred_os4, target=alpha, gauss_filter=self.gauss_filter, loss_type='l1', weight=weight_os4) * 1 / 5.0 * self.train_config.lap_weight
                self.loss_dict['lap_os8'] = self.lap_loss(logit=alpha_pred_os8, target=alpha, gauss_filter=self.gauss_filter, loss_type='l1', weight=weight_os8) * 1 / 5.0 * self.train_config.lap_weight

            for loss_value in self.loss_dict.values():
                if loss_value is not None:
                    loss += loss_value

            # ===== Back Propagate =====
            self.reset_grad()
            loss.backward()

            # ===== Clip Large Gradient =====
            if self.train_config.clip_grad:
                if moving_max_grad == 0:
                    # First clipped step: very loose bound, just to measure the gradient norm.
                    moving_max_grad = nn_utils.clip_grad_norm_(self.G.parameters(), 1e+6)
                    max_grad = moving_max_grad
                else:
                    # Clip at twice the running norm, then fold the new norm into the running value.
                    max_grad = nn_utils.clip_grad_norm_(self.G.parameters(), 2 * moving_max_grad)
                    moving_max_grad = moving_max_grad * moving_grad_moment + max_grad * (
                        1 - moving_grad_moment)

            # ===== Update Parameters =====
            self.G_optimizer.step()

            # ===== Write Log and Tensorboard =====
            # stdout log
            if step % self.log_config.logging_step == 0:
                # reduce losses from GPUs
                if CONFIG.dist:
                    self.loss_dict = utils.reduce_tensor_dict(self.loss_dict, mode='mean')
                    loss = utils.reduce_tensor(loss)

                # create logging information
                for loss_key, loss_value in self.loss_dict.items():
                    if loss_value is not None:
                        log_info += loss_key.upper() + ": {:.4f}, ".format(loss_value)

                if CONFIG.wandb and CONFIG.local_rank == 0:
                    for loss_key, loss_value in self.loss_dict.items():
                        if loss_value is not None:
                            wandb.log({'lr': cur_G_lr, 'total_loss': loss, loss_key.upper(): loss_value}, step=step)

                self.logger.debug("Image tensor shape: {}. Trimap tensor shape: {}".format(image.shape, trimap.shape))
                log_info = "[{}/{}], ".format(step, self.train_config.total_step) + log_info
                log_info += "lr: {:6f}".format(cur_G_lr)
                self.logger.info(log_info)

                # tensorboard
                # NOTE(review): nested under the logging-step branch, following upstream
                # MGMatting; the flattened source does not show the nesting level.
                if step % self.log_config.tensorboard_step == 0 or step == start:  # and step > start:
                    self.tb_logger.scalar_summary('Loss', loss, step)

                    # detailed losses
                    for loss_key, loss_value in self.loss_dict.items():
                        if loss_value is not None:
                            self.tb_logger.scalar_summary('Loss_' + loss_key.upper(), loss_value, step)

                    self.tb_logger.scalar_summary('LearnRate', cur_G_lr, step)

                    if self.train_config.clip_grad:
                        self.tb_logger.scalar_summary('Moving_Max_Grad', moving_max_grad, step)
                        self.tb_logger.scalar_summary('Max_Grad', max_grad, step)

            # Checkpoints are written by rank 0 only, and never on the very first step of a run.
            if (step % self.log_config.checkpoint_step == 0 or step == self.train_config.total_step) \
                    and CONFIG.local_rank == 0 and (step > start):
                # Log the step number (the builtin ``iter`` was being formatted here).
                self.logger.info('Saving the trained models from step {}...'.format(step))
                self.save_model("model_step_{}".format(step), step, loss)

            # NOTE(review): placed at loop level (runs every step); the flattened source
            # does not show whether it belonged inside the checkpoint branch.
            torch.cuda.empty_cache()

    def save_model(self, checkpoint_name, iter, loss):
        """Write a checkpoint to ``<log_config.checkpoint_path>/<checkpoint_name>.pth``.

        :param checkpoint_name: file name without the ``.pth`` suffix
        :param iter: training step stored under ``'iter'`` (name kept for caller compatibility)
        :param loss: loss value stored under ``'loss'``
        """
        # Only the ``m2m`` sub-module of the wrapped generator is saved.
        torch.save({
            'iter': iter,
            'loss': loss,
            'state_dict': self.G.module.m2m.state_dict(),
            'opt_state_dict': self.G_optimizer.state_dict(),
            'lr_state_dict': self.G_scheduler.state_dict()
        }, os.path.join(self.log_config.checkpoint_path, '{}.pth'.format(checkpoint_name)))

    @staticmethod
    def regression_loss(logit, target, loss_type='l1', weight=None):
        """
        Alpha reconstruction loss

        Without ``weight`` this is the mean L1/L2 loss. With ``weight`` both inputs
        are multiplied by it, the loss is summed, and the sum is divided by
        ``sum(weight) + 1e-8``.

        :param logit: predicted tensor
        :param target: ground-truth tensor
        :param loss_type: "l1" or "l2"
        :param weight: tensor with shape [N,1,H,W] weights for each pixel
        :return: scalar loss tensor
        :raises NotImplementedError: for any other ``loss_type``
        """
        if weight is None:
            if loss_type == 'l1':
                return F.l1_loss(logit, target)
            elif loss_type == 'l2':
                return F.mse_loss(logit, target)
            else:
                raise NotImplementedError("NotImplemented loss type {}".format(loss_type))
        else:
            if loss_type == 'l1':
                return F.l1_loss(logit * weight, target * weight, reduction='sum') / (torch.sum(weight) + 1e-8)
            elif loss_type == 'l2':
                return F.mse_loss(logit * weight, target * weight, reduction='sum') / (torch.sum(weight) + 1e-8)
            else:
                raise NotImplementedError("NotImplemented loss type {}".format(loss_type))

    @staticmethod
    def smooth_l1(logit, target, weight):
        """Return ``sum(sqrt((weighted difference)^2 + 1e-6)) / (sum(weight) + 1e-8)``."""
        loss = torch.sqrt((logit * weight - target * weight) ** 2 + 1e-6)
        loss = torch.sum(loss) / (torch.sum(weight) + 1e-8)
        return loss

    @staticmethod
    def mse(logit, target, weight):
        """Weighted mean squared error; delegates to ``regression_loss`` with ``loss_type='l2'``."""
        return Trainer.regression_loss(logit, target, loss_type='l2', weight=weight)

    @staticmethod
    def sad(logit, target, weight):
        """Sum of absolute differences over the weighted inputs, divided by 1000."""
        return F.l1_loss(logit * weight, target * weight, reduction='sum') / 1000

    @staticmethod
    def composition_loss(alpha, fg, bg, image, weight, loss_type='l1'):
        """
        Alpha composition loss

        Composites ``fg * alpha + bg * (1 - alpha)`` and compares the result with
        ``image`` using ``regression_loss``.
        """
        merged = fg * alpha + bg * (1 - alpha)
        return Trainer.regression_loss(merged, image, loss_type=loss_type, weight=weight)

    @staticmethod
    def gabor_loss(logit, target, gabor_filter, loss_type='l2', weight=None):
        """Regression loss between ``logit`` and ``target`` after convolving both with ``gabor_filter`` (padding 2)."""
        gabor_logit = F.conv2d(logit, weight=gabor_filter, padding=2)
        gabor_target = F.conv2d(target, weight=gabor_filter, padding=2)
        return Trainer.regression_loss(gabor_logit, gabor_target, loss_type=loss_type, weight=weight)

    @staticmethod
    def grad_loss(logit, target, grad_filter, loss_type='l1', weight=None):
        """Regression loss between gradient magnitudes.

        Both inputs are convolved with ``grad_filter`` (padding 1); the filter
        responses are combined per pixel as ``sqrt(sum over channels of r^2 + 1e-8)``.
        """
        grad_logit = F.conv2d(logit, weight=grad_filter, padding=1)
        grad_target = F.conv2d(target, weight=grad_filter, padding=1)
        grad_logit = torch.sqrt((grad_logit * grad_logit).sum(dim=1, keepdim=True) + 1e-8)
        grad_target = torch.sqrt((grad_target * grad_target).sum(dim=1, keepdim=True) + 1e-8)
        return Trainer.regression_loss(grad_logit, grad_target, loss_type=loss_type, weight=weight)

    @staticmethod
    def lap_loss(logit, target, gauss_filter, loss_type='l1', weight=None):
        '''
        Laplacian-pyramid loss: sum over 5 pyramid levels of
        ``regression_loss(level_i) * 2**i``.

        Based on FBA Matting implementation:
        https://gist.github.com/MarcoForte/a07c40a2b721739bb5c5987671aa5270

        :param logit: predicted tensor, [N,C,H,W]
        :param target: ground-truth tensor, same shape as ``logit``
        :param gauss_filter: smoothing kernel passed to ``F.conv2d`` with ``groups=C``
        :param loss_type: "l1" or "l2"
        :param weight: optional per-pixel weights, subsampled by 2 at each level
        :return: scalar loss tensor
        '''
        def conv_gauss(x, kernel):
            # Reflect-pad by 2 so the 5x5 convolution keeps the spatial size.
            x = F.pad(x, (2, 2, 2, 2), mode='reflect')
            x = F.conv2d(x, kernel, groups=x.shape[1])
            return x

        def downsample(x):
            return x[:, :, ::2, ::2]

        def upsample(x, kernel):
            # Zero-interleave rows, then columns, to double both spatial sizes, and
            # smooth with 4x the kernel. The zero padding is allocated on x's own
            # device/dtype (previously a hard-coded .cuda()), so this also runs on
            # CPU and on whichever GPU holds x.
            N, C, H, W = x.shape
            cc = torch.cat([x, torch.zeros_like(x)], dim=3)
            cc = cc.view(N, C, H * 2, W)
            cc = cc.permute(0, 1, 3, 2)
            cc = torch.cat([cc, torch.zeros(N, C, W, H * 2, device=x.device, dtype=x.dtype)], dim=3)
            cc = cc.view(N, C, W * 2, H * 2)
            x_up = cc.permute(0, 1, 3, 2)
            return conv_gauss(x_up, kernel=4 * gauss_filter)

        def lap_pyramid(x, kernel, max_levels=3):
            # Each level stores the detail lost by blur -> downsample -> upsample.
            current = x
            pyr = []
            for _ in range(max_levels):
                filtered = conv_gauss(current, kernel)
                down = downsample(filtered)
                up = upsample(down, kernel)
                pyr.append(current - up)
                current = down
            return pyr

        def weight_pyramid(x, max_levels=3):
            # Weights are subsampled (not blurred) to match each pyramid level.
            current = x
            pyr = []
            for _ in range(max_levels):
                pyr.append(current)
                current = downsample(current)
            return pyr

        pyr_logit = lap_pyramid(x=logit, kernel=gauss_filter, max_levels=5)
        pyr_target = lap_pyramid(x=target, kernel=gauss_filter, max_levels=5)
        if weight is not None:
            pyr_weight = weight_pyramid(x=weight, max_levels=5)
            return sum(Trainer.regression_loss(lvl_logit, lvl_target, loss_type=loss_type, weight=lvl_weight) * (2 ** i)
                       for i, (lvl_logit, lvl_target, lvl_weight)
                       in enumerate(zip(pyr_logit, pyr_target, pyr_weight)))
        else:
            return sum(Trainer.regression_loss(lvl_logit, lvl_target, loss_type=loss_type, weight=None) * (2 ** i)
                       for i, (lvl_logit, lvl_target) in enumerate(zip(pyr_logit, pyr_target)))
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化