加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test.py 3.76 KB
一键复制 编辑 原始数据 按行查看 历史
五粮液 提交于 2024-06-25 01:11 . Initial commit
import os, time
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from operator import add
import numpy as np
from glob import glob
import cv2
from tqdm import tqdm
import imageio
import torch
from model import New_UNet
from utils import create_dir, seeding
from utils import calculate_metrics
from train import load_data
def _synchronize(device):
    """Block until queued CUDA work on *device* has finished; no-op on non-CUDA devices."""
    # CUDA kernels run asynchronously, so wall-clock timing without a sync
    # measures kernel launch rather than actual inference time.
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _load_image_tensor(path, size, device):
    """Read a colour image and return it as a (1, 3, H, W) float32 tensor in [0, 1] on *device*.

    *size* is interpreted as (height, width).

    Raises:
        FileNotFoundError: if OpenCV cannot read *path*.
    """
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    # cv2.resize takes dsize as (width, height).
    image = cv2.resize(image, (size[1], size[0]))
    # NOTE(review): channels stay in cv2's BGR order; presumably this matches the
    # preprocessing used at training time — confirm against train.py.
    image = np.transpose(image, (2, 0, 1)) / 255.0  # HWC -> CHW, scale to [0, 1]
    image = np.expand_dims(image, axis=0).astype(np.float32)  # add batch axis
    return torch.from_numpy(image).to(device)


def _load_mask_tensor(path, size, device):
    """Read a grayscale mask and return it as a (1, 1, H, W) float32 tensor in [0, 1] on *device*.

    *size* is interpreted as (height, width).

    Raises:
        FileNotFoundError: if OpenCV cannot read *path*.
    """
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {path}")
    # NOTE(review): default (bilinear) interpolation can create non-binary values
    # along mask edges; kept to preserve existing metrics — consider cv2.INTER_NEAREST.
    mask = cv2.resize(mask, (size[1], size[0]))
    mask = (mask / 255.0).astype(np.float32)
    mask = mask[np.newaxis, np.newaxis, ...]  # (H, W) -> (1, 1, H, W)
    return torch.from_numpy(mask).to(device)


def _prediction_to_mask_image(y_pred):
    """Binarize a (1, 1, H, W) probability tensor at 0.5 into an (H, W, 3) uint8 image of 0/255."""
    prob = np.squeeze(y_pred[0].cpu().numpy(), axis=0)  # (1, H, W) -> (H, W)
    binary = np.where(prob > 0.5, 255, 0).astype(np.uint8)
    # Replicate to three channels so the file is written as a regular colour image.
    return np.repeat(binary[..., np.newaxis], 3, axis=2)


def evaluate(model, save_path, test_x, test_y, size, device=None):
    """Run *model* over paired image/mask files, save binarized predictions and report metrics.

    Args:
        model: segmentation network called as ``model(image)`` on a (1, 3, H, W)
            batch; a sigmoid is applied to its output.
        save_path: output root; predictions are written to
            ``<save_path>/mask/<name>.jpg`` (the directory is created if missing).
        test_x: image file paths.
        test_y: ground-truth mask file paths, index-aligned with *test_x*.
        size: (height, width) that images and masks are resized to.
        device: torch device used for inference. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        dict with the averaged "jaccard", "f1", "recall", "precision", "acc",
        "f2" scores and "mean_fps". The same values are also printed.

    Raises:
        ValueError: if *test_x* is empty or its length differs from *test_y*.
        FileNotFoundError: if an image or mask cannot be read.
        OSError: if a prediction image cannot be written.
    """
    if len(test_x) != len(test_y):
        raise ValueError(
            f"test_x and test_y differ in length ({len(test_x)} vs {len(test_y)})"
        )
    if len(test_x) == 0:
        raise ValueError("test set is empty")

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    mask_dir = os.path.join(save_path, "mask")
    os.makedirs(mask_dir, exist_ok=True)

    # Running sums of the per-sample scores from calculate_metrics; `map(add, ...)`
    # truncates to the shorter of the two sequences.
    metrics_score = [0.0] * 7
    time_taken = []

    for x, y in tqdm(zip(test_x, test_y), total=len(test_x)):
        # splitext keeps dots inside the stem, avoiding output-name collisions.
        name = os.path.splitext(os.path.basename(y))[0]
        image = _load_image_tensor(x, size, device)
        mask = _load_mask_tensor(y, size, device)

        with torch.no_grad():
            # NOTE(review): the first iteration includes warm-up cost; consider a
            # warm-up pass if FPS is reported externally.
            _synchronize(device)
            start_time = time.perf_counter()
            y_pred = torch.sigmoid(model(image))
            _synchronize(device)
            time_taken.append(time.perf_counter() - start_time)

            score = calculate_metrics(mask, y_pred)
            metrics_score = list(map(add, metrics_score, score))

            pred_image = _prediction_to_mask_image(y_pred)

        # NOTE(review): JPEG is lossy, so a 0/255 mask may gain intermediate values
        # on disk; PNG would be lossless. Kept .jpg to preserve output file names.
        out_path = os.path.join(mask_dir, f"{name}.jpg")
        if not cv2.imwrite(out_path, pred_image):
            # cv2.imwrite signals failure via its return value, not an exception.
            raise OSError(f"Could not write prediction: {out_path}")

    num_samples = len(test_x)
    jaccard, f1, recall, precision, acc, f2 = (s / num_samples for s in metrics_score[:6])
    print(f"IoU: {jaccard:1.4f} - Dice: {f1:1.4f} - Recall: {recall:1.4f} - Precision: {precision:1.4f} - Acc: {acc:1.4f} - F2: {f2:1.4f}")

    mean_time_taken = np.mean(time_taken)
    mean_fps = 1 / mean_time_taken
    print("Mean FPS: ", mean_fps)

    return {
        "jaccard": jaccard,
        "f1": f1,
        "recall": recall,
        "precision": precision,
        "acc": acc,
        "f2": f2,
        "mean_fps": float(mean_fps),
    }
if __name__ == "__main__":
    # Seeding
    seeding(42)

    # Load the checkpoint
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_name = 'New_UNet'
    dataset_name = 'ISIC2018'
    model = New_UNet(n_channels=3, num_classes=1)
    model = model.to(device)
    checkpoint_path = f"./files/{model_name}/{dataset_name}/checkpoint.pth"
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
    # (consider weights_only=True on PyTorch versions that support it).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    # Test dataset
    path = f"./Data/{dataset_name}/TestDataset/"
    (test_x, test_y) = load_data(path, 'test')
    # NOTE(review): images and masks are paired purely by sorted order; this
    # assumes their file names sort identically — confirm for the dataset.
    test_x = sorted(test_x)
    test_y = sorted(test_y)

    # Create the parent directory before its subdirectories.
    save_path = f"results/{model_name}"
    create_dir(save_path)
    # Iterate a list: iterating the bare string "mask" yields its characters and
    # created directories m/, a/, s/, k/ instead of mask/.
    for item in ["mask"]:
        create_dir(f"{save_path}/{item}")

    size = (256, 256)
    evaluate(model, save_path, test_x, test_y, size)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化