代码拉取完成,页面将自动刷新
同步操作将从 luoyongcoder/unet_42 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
# -*- coding: utf-8 -*-
"""
-------------------------------------------------
Project Name: unet
File Name: test.py
Author: chenming
Create Date: 2022/2/7
Description:
-------------------------------------------------
"""
import os
from tqdm import tqdm
from utils.utils_metrics import compute_mIoU, show_results
import glob
import numpy as np
import torch
import os
import cv2
from model.unet_model import UNet
def cal_miou(test_dir="C:/Users/chenmingsong/Desktop/unetnnn/skin/Test_Images",
pred_dir="C:/Users/chenmingsong/Desktop/unetnnn/skin/results", gt_dir="C:/Users/chenmingsong/Desktop/unetnnn/skin/Test_Labels"):
# ---------------------------------------------------------------------------#
# miou_mode用于指定该文件运行时计算的内容
# miou_mode为0代表整个miou计算流程,包括获得预测结果、计算miou。
# miou_mode为1代表仅仅获得预测结果。
# miou_mode为2代表仅仅计算miou。
# ---------------------------------------------------------------------------#
miou_mode = 0
# ------------------------------#
# 分类个数+1、如2+1
# ------------------------------#
num_classes = 2
# --------------------------------------------#
# 区分的种类,和json_to_dataset里面的一样
# --------------------------------------------#
name_classes = ["background", "nidus"]
# name_classes = ["_background_","cat","dog"]
# -------------------------------------------------------#
# 指向VOC数据集所在的文件夹
# 默认指向根目录下的VOC数据集
# -------------------------------------------------------#
# 计算结果和gt的结果进行比对
# 加载模型
if miou_mode == 0 or miou_mode == 1:
if not os.path.exists(pred_dir):
os.makedirs(pred_dir)
print("Load model.")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载网络,图片单通道,分类为1。
net = UNet(n_channels=1, n_classes=1)
# 将网络拷贝到deivce中
net.to(device=device)
# 加载模型参数
net.load_state_dict(torch.load('best_model_skin.pth', map_location=device)) # todo
# 测试模式
net.eval()
print("Load model done.")
img_names = os.listdir(test_dir)
image_ids = [image_name.split(".")[0] for image_name in img_names]
print("Get predict result.")
for image_id in tqdm(image_ids):
image_path = os.path.join(test_dir, image_id + ".jpg")
img = cv2.imread(image_path)
origin_shape = img.shape
# print(origin_shape)
# 转为灰度图
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
img = cv2.resize(img, (512, 512))
# 转为batch为1,通道为1,大小为512*512的数组
img = img.reshape(1, 1, img.shape[0], img.shape[1])
# 转为tensor
img_tensor = torch.from_numpy(img)
# 将tensor拷贝到device中,只用cpu就是拷贝到cpu中,用cuda就是拷贝到cuda中。
img_tensor = img_tensor.to(device=device, dtype=torch.float32)
# 预测
pred = net(img_tensor)
# 提取结果
pred = np.array(pred.data.cpu()[0])[0]
pred[pred >= 0.5] = 255
pred[pred < 0.5] = 0
pred = cv2.resize(pred, (origin_shape[1], origin_shape[0]), interpolation=cv2.INTER_NEAREST)
cv2.imwrite(os.path.join(pred_dir, image_id + ".png"), pred)
print("Get predict result done.")
if miou_mode == 0 or miou_mode == 2:
print("Get miou.")
print(gt_dir)
print(pred_dir)
print(num_classes)
print(name_classes)
hist, IoUs, PA_Recall, Precision = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes,
name_classes) # 执行计算mIoU的函数
print("Get miou done.")
miou_out_path = "results/"
show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes)
if __name__ == '__main__':
cal_miou()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。