加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
predict.py 2.01 KB
一键复制 编辑 原始数据 按行查看 历史
lllssskkk 提交于 2024-09-22 21:52 . dice loss
import glob
import numpy as np
import torch
import cv2
import json
from model.unet_model import UNet
import os
def predict(net,image_path,seg_path):
images_path = glob.glob(image_path)
for i in range(len(images_path)):
image_path = images_path[i]
# 保存结果地址
i_str = f"{(i+1):02d}"
save_res_path = seg_path.replace("*", i_str)
# 读取图片
img = cv2.imread(image_path)
origin_shape = img.shape
# 图片预测
# 转为大小为batch为1,通道数为1*512*512的数组
img = cv2.resize(img, (512, 512))
img = img.transpose()
img = img.reshape(1, 3, 512, 512)
# 转为tensor
img_tensor = torch.from_numpy(img)
# 将tensor拷贝到device中
img_tensor = img_tensor.to(device=device, dtype=torch.float32)
# 预测
pred = net(img_tensor)
pred = torch.sigmoid(pred)
# 提取结果
pred = np.array(pred.data.cpu()[0])[0]
print(pred.max(),pred.min())
# 处理结果
pred[pred >= 0.51] = 255
pred[pred < 0.51] = 0
# 保存图片
pred = cv2.resize(pred, (origin_shape[1], origin_shape[0]), interpolation=cv2.INTER_NEAREST)
cv2.imwrite(save_res_path, pred)
if __name__ == "__main__":
# 加载网络,三通道一分类
net = UNet(n_channels=3, n_classes=1)
# 将网络拷贝到deivce中
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device=device)
# 加载模型参数
dir = str(os.path.dirname(__file__))
# print(dir)
model = torch.load(dir+'/model.pth',weights_only=True, map_location=device)
net.load_state_dict(model['model_state_dict'])
print(model['loss'])
# 测试模式
net.eval()
# 读取图片路径
with open("configs.json", "r") as f:
config = json.load(f)
image_path = config["dataset"]["test"]["images_path"]
seg_path = config["dataset"]["test"]["seg_path"]
predict(net,image_path,seg_path)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化