加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
infer.py 1.41 KB
一键复制 编辑 原始数据 按行查看 历史
Shivelino 提交于 2023-12-26 16:48 . chore: 少量修改
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@file infer.py
@brief
@details
@author Shivelino
@date 2023-12-23 19:10
@version 0.0.1
@par Copyright(c):
@par todo:
@par history:
"""
import torch
import argparse
import cv2
import torchvision.transforms as transforms
from nets import get_model
from utils import get_device
def _load_model(name, device):
    """Build network ``name`` and load its weights from ``model/model_<name>.pth``.

    Args:
        name: Model identifier passed to ``get_model``.
        device: Target device from ``get_device()``. Assumed to be a
            ``torch.device`` or device string -- TODO confirm in utils.

    Returns:
        The model placed on ``device`` and switched to evaluation mode.
    """
    model = get_model(name).to(device)
    # map_location remaps the stored tensors onto the inference device, so a
    # checkpoint saved from a GPU run can still be loaded on a CPU-only machine.
    state_dict = torch.load(f'model/model_{name}.pth', map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _load_image(img_path):
    """Read ``img_path`` as grayscale and turn it into a normalized model input.

    Args:
        img_path: Path of the image file to classify.

    Returns:
        A float tensor of shape (1, 1, 28, 28): a batch of one single-channel
        28x28 image, normalized with mean 0.5 / std 0.5.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
    """
    img_np = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise fail later with an obscure cv2.resize assertion.
    if img_np is None:
        raise FileNotFoundError(f"cannot read image: {img_path}")
    img_np = cv2.resize(img_np, (28, 28))
    transform = transforms.Compose([
        transforms.ToTensor(),                 # HxW array -> 1xHxW float tensor
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5
    ])
    # Prepend the batch dimension the model expects.
    return transform(img_np).unsqueeze(0)


def infer(opt):
    """Classify one hand-written digit image and print the prediction.

    Args:
        opt: Parsed options with attributes ``model`` (network name used to
            pick ``model/model_<model>.pth``) and ``img_path`` (image to classify).

    Returns:
        A ``(digit, confidence)`` tuple: the predicted class index and its
        softmax probability as a float in [0, 1].

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    device = get_device()
    model = _load_model(opt.model, device)
    img_tensor = _load_image(opt.img_path)

    with torch.no_grad():
        logits = model(img_tensor.to(device)).to("cpu")
        # Batch of one: convert sample 0's logits into class probabilities.
        probs = torch.softmax(logits[0], dim=0)
    result = int(torch.argmax(probs))
    confidence = probs[result].item()
    print(f"Hand-writing number: {result}; confidence: {confidence * 100:.2f}%")
    return result, confidence
if __name__ == '__main__':
    # Command-line entry point: parse the options and run one inference.
    cli = argparse.ArgumentParser()
    cli.add_argument('--model', default="lenet", type=str, help='model')
    cli.add_argument('--img_path', default="data/img/0.jpg", type=str,
                     help='image path to infer')
    cli_args = cli.parse_args()
    infer(cli_args)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化