加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
engine.py 5.82 KB
一键复制 编辑 原始数据 按行查看 历史
wurenkai 提交于 2024-04-01 09:36 . Add files via upload
import numpy as np
from tqdm import tqdm
import torch
from torch.cuda.amp import autocast as autocast
from sklearn.metrics import confusion_matrix
from utils import save_imgs
def train_one_epoch(train_loader,
                    model,
                    criterion,
                    optimizer,
                    scheduler,
                    epoch,
                    logger,
                    config,
                    scaler=None):
    """Train ``model`` for one epoch, then advance ``scheduler`` once.

    Args:
        train_loader: iterable yielding ``(images, targets)`` tensor pairs;
            both are moved to the current CUDA device and cast to float.
        model: network being trained; switched to train mode here.
        criterion: loss, called as ``criterion(model(images), targets)``.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once, after the last batch.
        epoch: epoch index, used only in log messages.
        logger: object with an ``info(str)`` method.
        config: must provide ``amp`` (bool) and ``print_interval`` (int).
        scaler: used as ``scaler.scale(loss).backward()`` /
            ``scaler.step(optimizer)`` / ``scaler.update()`` when
            ``config.amp`` is true (presumably a ``GradScaler``).

    Raises:
        ValueError: if ``config.amp`` is enabled but ``scaler`` is None
            (previously this crashed with AttributeError on the first batch).
    """
    if config.amp and scaler is None:
        raise ValueError('config.amp is enabled but no GradScaler was passed as `scaler`')

    # switch to train mode
    model.train()
    loss_list = []
    for step, (images, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        images = images.cuda(non_blocking=True).float()
        targets = targets.cuda(non_blocking=True).float()
        if config.amp:
            # Only the forward pass and loss run under autocast; the scaled
            # backward/step happen outside the autocast context.
            with autocast():
                out = model(images)
                loss = criterion(out, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(images)
            loss = criterion(out, targets)
            loss.backward()
            optimizer.step()
        loss_list.append(loss.item())
        if step % config.print_interval == 0:
            # Read the LR straight from param_groups: optimizer.state_dict()
            # would build a packed copy of the whole optimizer state each time.
            now_lr = optimizer.param_groups[0]['lr']
            log_info = f'train: epoch {epoch}, iter:{step}, loss: {np.mean(loss_list):.4f}, lr: {now_lr}'
            print(log_info)
            logger.info(log_info)
    scheduler.step()
def _flatten_batches(batches):
    """Concatenate a list of per-batch numpy arrays into one flat 1-D array.

    ``np.array(batches).reshape(-1)`` only works when every batch has the
    same shape; a smaller final batch makes the list ragged, which NumPy
    cannot stack into a regular array.  Flattening each batch first and
    concatenating yields the same element order and tolerates ragged batches.
    An empty list yields an empty array.
    """
    if not batches:
        return np.empty(0)
    return np.concatenate([np.asarray(batch).reshape(-1) for batch in batches])


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or 0.0 if the denominator is 0."""
    return float(numerator) / float(denominator) if denominator != 0 else 0.0


def _binary_metrics(preds, gts, threshold):
    """Compute pixel-wise binary segmentation metrics.

    Args:
        preds: flat array of predicted scores; values ``>= threshold`` are positive.
        gts: flat array of ground-truth values; values ``>= 0.5`` are positive.
        threshold: decision threshold applied to ``preds``.

    Returns:
        dict with ``confusion`` (2x2 int array laid out ``[[TN, FP], [FN, TP]]``)
        and float ``accuracy``, ``sensitivity``, ``specificity``,
        ``f1_or_dsc`` and ``miou``.
    """
    y_pre = np.where(preds >= threshold, 1, 0)
    y_true = np.where(gts >= 0.5, 1, 0)
    # Same layout as sklearn's confusion_matrix(y_true, y_pre, labels=[0, 1]),
    # but O(n) and always 2x2: without labels, sklearn returns a 1x1 matrix
    # when only one class occurs, and indexing confusion[1, 1] raised IndexError.
    confusion = np.bincount((2 * y_true + y_pre).ravel(), minlength=4).reshape(2, 2)
    TN, FP, FN, TP = confusion.ravel()
    return {
        'confusion': confusion,
        'accuracy': _safe_ratio(TN + TP, confusion.sum()),
        'sensitivity': _safe_ratio(TP, TP + FN),
        'specificity': _safe_ratio(TN, TN + FP),
        'f1_or_dsc': _safe_ratio(2 * TP, 2 * TP + FP + FN),
        'miou': _safe_ratio(TP, TP + FP + FN),
    }


def val_one_epoch(test_loader,
                  model,
                  criterion,
                  epoch,
                  logger,
                  config):
    """Evaluate ``model`` on ``test_loader`` and log the results.

    The mean loss is logged every epoch.  Pixel metrics (mIoU, F1/DSC,
    accuracy, specificity, sensitivity, confusion matrix) are computed and
    logged only when ``epoch % config.val_interval == 0``.

    Args:
        test_loader: iterable yielding ``(img, msk)`` tensor pairs.
        model: network under evaluation; switched to eval mode here.  If it
            returns a tuple, its first element is used as the prediction.
        criterion: loss, called as ``criterion(model(img), msk)``.
        epoch: epoch index, used for logging and the metric schedule.
        logger: object with an ``info(str)`` method.
        config: must provide ``val_interval`` and ``threshold``.

    Returns:
        Mean of the per-batch losses.
    """
    # switch to evaluate mode
    model.eval()
    # Per-pixel predictions are only needed on metric epochs; skip the
    # device-to-host copies and host memory on the other epochs.
    compute_metrics = epoch % config.val_interval == 0
    preds = []
    gts = []
    loss_list = []
    with torch.no_grad():
        for img, msk in tqdm(test_loader):
            img = img.cuda(non_blocking=True).float()
            msk = msk.cuda(non_blocking=True).float()
            out = model(img)
            loss = criterion(out, msk)
            loss_list.append(loss.item())
            if not compute_metrics:
                continue
            gts.append(msk.squeeze(1).cpu().numpy())
            if isinstance(out, tuple):
                out = out[0]
            preds.append(out.squeeze(1).cpu().numpy())

    mean_loss = np.mean(loss_list)
    if compute_metrics:
        metrics = _binary_metrics(_flatten_batches(preds), _flatten_batches(gts), config.threshold)
        log_info = (f'val epoch: {epoch}, loss: {mean_loss:.4f}, miou: {metrics["miou"]}, '
                    f'f1_or_dsc: {metrics["f1_or_dsc"]}, accuracy: {metrics["accuracy"]}, '
                    f'specificity: {metrics["specificity"]}, sensitivity: {metrics["sensitivity"]}, '
                    f'confusion_matrix: {metrics["confusion"]}')
    else:
        log_info = f'val epoch: {epoch}, loss: {mean_loss:.4f}'
    print(log_info)
    logger.info(log_info)
    return mean_loss
def test_one_epoch(test_loader,
                   model,
                   criterion,
                   logger,
                   config,
                   test_data_name=None):
    """Evaluate ``model`` on ``test_loader``, save per-batch images and log metrics.

    Each batch's input, mask and prediction are passed to ``save_imgs`` with
    output directory ``config.work_dir + 'outputs/'``.  Pixel metrics over
    the whole loader are then computed and logged.

    Args:
        test_loader: iterable yielding ``(img, msk)`` tensor pairs.
        model: network under evaluation; switched to eval mode here.  If it
            returns a tuple, its first element is used as the prediction.
        criterion: loss, called as ``criterion(model(img), msk)``.
        logger: object with an ``info(str)`` method.
        config: must provide ``work_dir``, ``datasets`` and ``threshold``.
        test_data_name: optional dataset name, forwarded to ``save_imgs`` and
            logged when given.

    Returns:
        Mean of the per-batch losses.
    """
    # switch to evaluate mode
    model.eval()
    preds = []
    gts = []
    loss_list = []
    with torch.no_grad():
        for i, (img, msk) in enumerate(tqdm(test_loader)):
            img = img.cuda(non_blocking=True).float()
            msk = msk.cuda(non_blocking=True).float()
            out = model(img)
            loss = criterion(out, msk)
            loss_list.append(loss.item())
            msk = msk.squeeze(1).cpu().numpy()
            gts.append(msk)
            if isinstance(out, tuple):
                out = out[0]
            out = out.squeeze(1).cpu().numpy()
            preds.append(out)
            save_imgs(img, msk, out, i, config.work_dir + 'outputs/', config.datasets,
                      config.threshold, test_data_name=test_data_name)

    # Flatten each batch before concatenating: np.array(list).reshape(-1)
    # fails when the last batch is smaller (ragged list of arrays).
    preds = np.concatenate([p.reshape(-1) for p in preds]) if preds else np.empty(0)
    gts = np.concatenate([g.reshape(-1) for g in gts]) if gts else np.empty(0)
    y_pre = np.where(preds >= config.threshold, 1, 0)
    y_true = np.where(gts >= 0.5, 1, 0)
    # 2x2 confusion matrix [[TN, FP], [FN, TP]] (same layout as sklearn's
    # confusion_matrix(y_true, y_pre, labels=[0, 1])).  It stays 2x2 even when
    # only one class occurs, where the unlabeled sklearn call returned 1x1 and
    # the indexing below raised IndexError.
    confusion = np.bincount((2 * y_true + y_pre).ravel(), minlength=4).reshape(2, 2)
    TN, FP, FN, TP = confusion.ravel()
    total = confusion.sum()
    accuracy = float(TN + TP) / float(total) if total != 0 else 0.0
    sensitivity = float(TP) / float(TP + FN) if TP + FN != 0 else 0.0
    specificity = float(TN) / float(TN + FP) if TN + FP != 0 else 0.0
    f1_or_dsc = float(2 * TP) / float(2 * TP + FP + FN) if 2 * TP + FP + FN != 0 else 0.0
    miou = float(TP) / float(TP + FP + FN) if TP + FP + FN != 0 else 0.0

    mean_loss = np.mean(loss_list)
    if test_data_name is not None:
        log_info = f'test_datasets_name: {test_data_name}'
        print(log_info)
        logger.info(log_info)
    log_info = (f'test of best model, loss: {mean_loss:.4f},miou: {miou}, f1_or_dsc: {f1_or_dsc}, '
                f'accuracy: {accuracy}, specificity: {specificity}, sensitivity: {sensitivity}, '
                f'confusion_matrix: {confusion}')
    print(log_info)
    logger.info(log_info)
    return mean_loss
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化