加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
utils.py 9.89 KB
一键复制 编辑 原始数据 按行查看 历史
Renkai Wu 提交于 2024-04-11 09:59 . Update utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms.functional as TF
import numpy as np
import os
import math
import random
import logging
import logging.handlers
from matplotlib import pyplot as plt
def set_seed(seed):
# for hash
os.environ['PYTHONHASHSEED'] = str(seed)
# for python and numpy
random.seed(seed)
np.random.seed(seed)
# for cpu gpu
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# for cudnn
cudnn.benchmark = False
cudnn.deterministic = True
def get_logger(name, log_dir):
'''
Args:
name(str): name of logger
log_dir(str): path of log
'''
if not os.path.exists(log_dir):
os.makedirs(log_dir)
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
info_name = os.path.join(log_dir, '{}.info.log'.format(name))
info_handler = logging.handlers.TimedRotatingFileHandler(info_name,
when='D',
encoding='utf-8')
info_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
info_handler.setFormatter(formatter)
logger.addHandler(info_handler)
return logger
def log_config_info(config, logger):
config_dict = config.__dict__
log_info = f'#----------Config info----------#'
logger.info(log_info)
for k, v in config_dict.items():
if k[0] == '_':
continue
else:
log_info = f'{k}: {v},'
logger.info(log_info)
def get_optimizer(config, model):
assert config.opt in ['Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD', 'RMSprop', 'Rprop', 'SGD'], 'Unsupported optimizer!'
if config.opt == 'Adadelta':
return torch.optim.Adadelta(
model.parameters(),
lr = config.lr,
rho = config.rho,
eps = config.eps,
weight_decay = config.weight_decay
)
elif config.opt == 'Adagrad':
return torch.optim.Adagrad(
model.parameters(),
lr = config.lr,
lr_decay = config.lr_decay,
eps = config.eps,
weight_decay = config.weight_decay
)
elif config.opt == 'Adam':
return torch.optim.Adam(
model.parameters(),
lr = config.lr,
betas = config.betas,
eps = config.eps,
weight_decay = config.weight_decay,
amsgrad = config.amsgrad
)
elif config.opt == 'AdamW':
return torch.optim.AdamW(
model.parameters(),
lr = config.lr,
betas = config.betas,
eps = config.eps,
weight_decay = config.weight_decay,
amsgrad = config.amsgrad
)
elif config.opt == 'Adamax':
return torch.optim.Adamax(
model.parameters(),
lr = config.lr,
betas = config.betas,
eps = config.eps,
weight_decay = config.weight_decay
)
elif config.opt == 'ASGD':
return torch.optim.ASGD(
model.parameters(),
lr = config.lr,
lambd = config.lambd,
alpha = config.alpha,
t0 = config.t0,
weight_decay = config.weight_decay
)
elif config.opt == 'RMSprop':
return torch.optim.RMSprop(
model.parameters(),
lr = config.lr,
momentum = config.momentum,
alpha = config.alpha,
eps = config.eps,
centered = config.centered,
weight_decay = config.weight_decay
)
elif config.opt == 'Rprop':
return torch.optim.Rprop(
model.parameters(),
lr = config.lr,
etas = config.etas,
step_sizes = config.step_sizes,
)
elif config.opt == 'SGD':
return torch.optim.SGD(
model.parameters(),
lr = config.lr,
momentum = config.momentum,
weight_decay = config.weight_decay,
dampening = config.dampening,
nesterov = config.nesterov
)
else: # default opt is SGD
return torch.optim.SGD(
model.parameters(),
lr = 0.01,
momentum = 0.9,
weight_decay = 0.05,
)
def get_scheduler(config, optimizer):
assert config.sch in ['StepLR', 'MultiStepLR', 'ExponentialLR', 'CosineAnnealingLR', 'ReduceLROnPlateau',
'CosineAnnealingWarmRestarts', 'WP_MultiStepLR', 'WP_CosineLR'], 'Unsupported scheduler!'
if config.sch == 'StepLR':
scheduler = torch.optim.lr_scheduler.StepLR(
optimizer,
step_size = config.step_size,
gamma = config.gamma,
last_epoch = config.last_epoch
)
elif config.sch == 'MultiStepLR':
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones = config.milestones,
gamma = config.gamma,
last_epoch = config.last_epoch
)
elif config.sch == 'ExponentialLR':
scheduler = torch.optim.lr_scheduler.ExponentialLR(
optimizer,
gamma = config.gamma,
last_epoch = config.last_epoch
)
elif config.sch == 'CosineAnnealingLR':
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max = config.T_max,
eta_min = config.eta_min,
last_epoch = config.last_epoch
)
elif config.sch == 'ReduceLROnPlateau':
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode = config.mode,
factor = config.factor,
patience = config.patience,
threshold = config.threshold,
threshold_mode = config.threshold_mode,
cooldown = config.cooldown,
min_lr = config.min_lr,
eps = config.eps
)
elif config.sch == 'CosineAnnealingWarmRestarts':
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer,
T_0 = config.T_0,
T_mult = config.T_mult,
eta_min = config.eta_min,
last_epoch = config.last_epoch
)
elif config.sch == 'WP_MultiStepLR':
lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else config.gamma**len(
[m for m in config.milestones if m <= epoch])
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)
elif config.sch == 'WP_CosineLR':
lr_func = lambda epoch: epoch / config.warm_up_epochs if epoch <= config.warm_up_epochs else 0.5 * (
math.cos((epoch - config.warm_up_epochs) / (config.epochs - config.warm_up_epochs) * math.pi) + 1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)
return scheduler
def save_imgs(img, msk, msk_pred, i, save_path, datasets, threshold=0.5, test_data_name=None):
img = img.squeeze(0).permute(1,2,0).detach().cpu().numpy()
img = img / 255. if img.max() > 1.1 else img
if datasets == 'retinal':
msk = np.squeeze(msk, axis=0)
msk_pred = np.squeeze(msk_pred, axis=0)
else:
msk = np.where(np.squeeze(msk, axis=0) > 0.5, 1, 0)
msk_pred = np.where(np.squeeze(msk_pred, axis=0) > threshold, 1, 0)
plt.figure(figsize=(7,15))
plt.subplot(3,1,1)
plt.imshow(img)
plt.axis('off')
plt.subplot(3,1,2)
plt.imshow(msk, cmap= 'gray')
plt.axis('off')
plt.subplot(3,1,3)
plt.imshow(msk_pred, cmap = 'gray')
plt.axis('off')
if test_data_name is not None:
save_path = save_path + test_data_name + '_'
plt.savefig(save_path + str(i) +'.png')
plt.close()
class BCELoss(nn.Module):
def __init__(self):
super(BCELoss, self).__init__()
self.bceloss = nn.BCELoss()
def forward(self, pred, target):
size = pred.size(0)
pred_ = pred.view(size, -1)
target_ = target.view(size, -1)
return self.bceloss(pred_, target_)
class DiceLoss(nn.Module):
def __init__(self):
super(DiceLoss, self).__init__()
def forward(self, pred, target):
smooth = 1
size = pred.size(0)
pred_ = pred.view(size, -1)
target_ = target.view(size, -1)
intersection = pred_ * target_
dice_score = (2 * intersection.sum(1) + smooth)/(pred_.sum(1) + target_.sum(1) + smooth)
dice_loss = 1 - dice_score.sum()/size
return dice_loss
class BceDiceLoss(nn.Module):
def __init__(self, wb=1, wd=1):
super(BceDiceLoss, self).__init__()
self.bce = BCELoss()
self.dice = DiceLoss()
self.wb = wb
self.wd = wd
def forward(self, pred, target):
bceloss = self.bce(pred, target)
diceloss = self.dice(pred, target)
loss = self.wd * diceloss + self.wb * bceloss
return loss
from thop import profile ## 导入thop模块
def cal_params_flops(model, size, logger):
input = torch.randn(1, 3, size, size).cuda()
flops, params = profile(model, inputs=(input,))
print('flops',flops/1e9) ## 打印计算量
print('params',params/1e6) ## 打印参数量
total = sum(p.numel() for p in model.parameters())
print("Total params: %.3fM" % (total/1e6))
logger.info(f'flops: {flops/1e9}, params: {params/1e6}, Total params: : {total/1e6:.4f}')
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化