代码拉取完成,页面将自动刷新
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@file infer.py
@brief
@details
@author Shivelino
@date 2023-12-23 19:10
@version 0.0.1
@par Copyright(c):
@par todo:
@par history:
"""
import torch
import argparse
import cv2
import torchvision.transforms as transforms
from nets import get_model
from utils import get_device
def infer(opt):
# load model
device = get_device()
model = get_model(opt.model).to(device)
model.load_state_dict(torch.load(f'model/model_{opt.model}.pth'))
softmax = torch.nn.Softmax(dim=0)
# read image
img_np = cv2.imread(opt.img_path, cv2.IMREAD_GRAYSCALE)
img_np = cv2.resize(img_np, (28, 28))
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
img_tensor = transform(img_np).unsqueeze_(0)
# infer
model.eval()
with torch.no_grad():
img_tensor = img_tensor.to(device)
outputs = model(img_tensor).to("cpu")
output = softmax(outputs[0])
result = int(torch.argmax(output))
confidence = output[result]
print(f"Hand-writing number: {result}; confidence: {confidence * 100: .2f}%")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default="lenet", help='model')
parser.add_argument('--img_path', type=str, default="data/img/0.jpg", help='image path to infer')
infer(parser.parse_args())
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。