代码拉取完成,页面将自动刷新
同步操作将从 五粮液/MobileNet 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import os, time
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from operator import add
import numpy as np
from glob import glob
import cv2
from tqdm import tqdm
import imageio
import torch
from model import New_UNet
from utils import create_dir, seeding
from utils import calculate_metrics
from train import load_data
def _read_image(path, size, device):
    """Read a colour image and return a (1, 3, H, W) float32 tensor in [0, 1] on *device*."""
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals an unreadable/missing file by returning None, not raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    # cv2.resize takes dsize as (width, height); *size* is (height, width).
    image = cv2.resize(image, (size[1], size[0]))
    # HWC -> CHW, scale to [0, 1], add a batch axis. Channel order is OpenCV's
    # BGR; presumably this matches how train.py feeds the model -- TODO confirm.
    image = np.transpose(image, (2, 0, 1)) / 255.0
    image = np.expand_dims(image, axis=0).astype(np.float32)
    return torch.from_numpy(image).to(device)


def _read_mask(path, size, device):
    """Read a grayscale mask and return a (1, 1, H, W) float32 tensor in [0, 1] on *device*."""
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {path}")
    mask = cv2.resize(mask, (size[1], size[0]))
    mask = mask[np.newaxis, np.newaxis, ...] / 255.0
    return torch.from_numpy(mask.astype(np.float32)).to(device)


def _prediction_to_bgr(prob, threshold=0.5):
    """Binarise a (1, 1, H, W) probability tensor into an (H, W, 3) uint8 image of 0/255."""
    # squeeze(axis=0) drops the channel axis and fails loudly if it is not size 1.
    prob = np.squeeze(prob[0].detach().cpu().numpy(), axis=0)
    gray = (prob > threshold).astype(np.uint8) * 255
    # Replicate to three identical channels so the saved file is a BGR image.
    return np.repeat(gray[..., np.newaxis], 3, axis=2)


def evaluate(model, save_path, test_x, test_y, size, device=None):
    """Run *model* over a test set, save thresholded predictions and print metrics.

    For each (image, mask) pair, the image is passed through the model and a
    sigmoid is applied. The result is scored against the ground-truth mask with
    ``calculate_metrics``. The prediction, thresholded at 0.5, is written to
    ``{save_path}/mask/{name}.jpg``, where *name* is the mask file's stem. At
    the end, the averaged metrics and the mean FPS of the forward pass are
    printed.

    Args:
        model: torch module. The caller is expected to have put it in eval mode
            (the script entry point does). Its output is assumed to be
            (1, 1, H, W) logits -- TODO confirm for other models.
        save_path: output root directory; ``mask/`` is created under it if missing.
        test_x: image file paths.
        test_y: ground-truth mask file paths, paired positionally with test_x.
        size: (height, width) that images and masks are resized to.
        device: device for input tensors. Defaults to the device of the model's
            first parameter, or CPU if the model has no parameters.

    Raises:
        ValueError: if the test set is empty or test_x and test_y differ in length.
        FileNotFoundError: if an image or mask cannot be read.
        OSError: if a prediction cannot be written.
    """
    if len(test_x) != len(test_y):
        # zip() would silently truncate and the averages below would be wrong.
        raise ValueError(
            f"test_x has {len(test_x)} entries but test_y has {len(test_y)}"
        )
    if len(test_x) == 0:
        raise ValueError("Test set is empty; nothing to evaluate")

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    mask_dir = os.path.join(save_path, "mask")
    os.makedirs(mask_dir, exist_ok=True)

    # Running totals; map(add, ...) below stops at the shorter sequence, so only
    # as many totals survive as calculate_metrics returns values.
    metrics_score = [0.0] * 7
    time_taken = []

    for x, y in tqdm(zip(test_x, test_y), total=len(test_x)):
        # basename/splitext respect the OS path separator and strip only the
        # final extension (names containing dots are kept intact).
        name = os.path.splitext(os.path.basename(y))[0]

        image = _read_image(x, size, device)
        mask = _read_mask(y, size, device)

        with torch.no_grad():
            # FPS: time forward pass + sigmoid only. CUDA work is asynchronous,
            # so synchronise before stopping the clock; otherwise only the kernel
            # launches are measured. NOTE: early iterations may include warm-up.
            start_time = time.perf_counter()
            y_pred = torch.sigmoid(model(image))
            if y_pred.is_cuda:
                torch.cuda.synchronize(y_pred.device)
            time_taken.append(time.perf_counter() - start_time)

            score = calculate_metrics(mask, y_pred)
            metrics_score = list(map(add, metrics_score, score))

        # NOTE(review): JPEG is lossy, so the saved binary mask may pick up
        # compression artefacts; .png would preserve exact 0/255 values.
        out_path = os.path.join(mask_dir, f"{name}.jpg")
        if not cv2.imwrite(out_path, _prediction_to_bgr(y_pred)):
            # cv2.imwrite reports failure via its return value, not an exception.
            raise OSError(f"Could not write prediction: {out_path}")

    n = len(test_x)
    jaccard = metrics_score[0] / n
    f1 = metrics_score[1] / n
    recall = metrics_score[2] / n
    precision = metrics_score[3] / n
    acc = metrics_score[4] / n
    f2 = metrics_score[5] / n
    print(f"IoU: {jaccard:1.4f} - Dice: {f1:1.4f} - Recall: {recall:1.4f} - Precision: {precision:1.4f} - Acc: {acc:1.4f} - F2: {f2:1.4f}")

    mean_time_taken = np.mean(time_taken)
    mean_fps = 1 / mean_time_taken
    print("Mean FPS: ", mean_fps)
if __name__ == "__main__":
    # Seed RNGs via the project helper (presumably for reproducible runs).
    seeding(42)

    # Build the model and load the trained checkpoint onto the chosen device.
    # `device` stays module-level because evaluate() may read it as a global.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_name = 'New_UNet'
    dataset_name = 'ISIC2018'
    model = New_UNet(n_channels=3, num_classes=1)
    model = model.to(device)
    checkpoint_path = f"./files/{model_name}/{dataset_name}/checkpoint.pth"
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (weights_only=True is safer where the torch version supports it).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    # Test dataset. Images and masks are sorted independently, so correct
    # pairing relies on both lists sorting into the same order -- TODO confirm
    # against the dataset's file naming.
    path = f"./Data/{dataset_name}/TestDataset/"
    (test_x, test_y) = load_data(path, 'test')
    test_x = sorted(test_x)
    test_y = sorted(test_y)

    # Create the output root first, then the sub-directory predictions go to.
    # Use a tuple: iterating the bare string "mask" yields 'm', 'a', 's', 'k'.
    save_path = f"results/{model_name}"
    create_dir(save_path)
    for item in ("mask",):
        create_dir(f"{save_path}/{item}")

    size = (256, 256)
    evaluate(model, save_path, test_x, test_y, size)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。