加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
demo.py 6.66 KB
一键复制 编辑 原始数据 按行查看 历史
Shivelino 提交于 2023-12-26 16:48 . chore: 少量修改
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@file demo.py
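@brief Hand-written digit recognition demo: draw a digit on a PyQt5 board and classify it with a PyTorch model.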
@details
@author Shivelino
@date 2023-12-23 19:10
@version 0.0.1
@par Copyright(c):
@par todo:
@par history:
"""
import cv2
import torch
import argparse
import torchvision.transforms as transforms
from nets import get_model
from utils import get_device
from PyQt5 import QtCore, QtGui, QtWidgets
import sys
from PyQt5.QtWidgets import QApplication, QLabel, QMainWindow
from PyQt5.QtGui import QPainter, QPen
from PyQt5.QtCore import Qt
import numpy as np
def qimage2opencv(image):
    """Copy a QImage into a ``(height, width, 3)`` uint8 array in BGR order.

    Each pixel is read through ``image.pixelColor(x, y)``; only the blue,
    green and red components are stored (alpha is not read).
    """
    cols = image.width()
    rows = image.height()
    bgr = np.zeros((rows, cols, 3), dtype=np.uint8)
    # ndindex walks (row, col) pairs row by row, like a nested y/x loop.
    for row, col in np.ndindex(rows, cols):
        colour = image.pixelColor(col, row)
        bgr[row, col] = (colour.blue(), colour.green(), colour.red())  # BGR
    return bgr
class Inferior(object):
    """Wrap a trained classifier and run single-image inference.

    The network is built by ``get_model(opt.model)``, moved to the device
    returned by ``get_device()``, and its weights are read from
    ``model/model_<opt.model>.pth``.
    """

    def __init__(self, opt):
        """Load the model selected by ``opt.model`` and set up preprocessing.

        Args:
            opt: Object with a ``model`` attribute (e.g. an argparse
                namespace). Its value selects both the architecture and the
                checkpoint file name.
        """
        self.device = get_device()
        self.model = get_model(opt.model).to(self.device)
        # map_location remaps the stored tensors onto the device in use, so a
        # checkpoint saved on a GPU still loads on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(f'model/model_{opt.model}.pth', map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        # Applied to a single sample's 1-D score vector, hence dim=0.
        self.softmax = torch.nn.Softmax(dim=0)
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    def infer(self, img_np):
        """Classify one image.

        Args:
            img_np: Image accepted by ``self.transform``. The caller in this
                file passes a 28x28 single-channel uint8 array; the transform
                is what adds the channel axis.

        Returns:
            tuple: ``(result, confidence)`` where ``result`` is the ``int``
            index of the highest softmax score and ``confidence`` is that
            score as a 0-dim tensor.
        """
        # Add the batch axis the model expects in front of the sample.
        img_tensor = self.transform(img_np).unsqueeze(0)
        with torch.no_grad():
            outputs = self.model(img_tensor.to(self.device)).to("cpu")
            scores = self.softmax(outputs[0])
            result = int(torch.argmax(scores))
            confidence = scores[result]
        return result, confidence
class MainWindow(QMainWindow):
    """Main window: a square drawing board plus result/confidence read-outs.

    Mouse usage:
        * left button drag   - draw a stroke on the board
        * right button press - run inference on the board content
        * left double click  - clear the board and the read-outs
    """

    # The board sits at (_BOARD_MARGIN, _BOARD_MARGIN) in the window, so
    # window-relative mouse positions are shifted by this much before being
    # used as pixmap coordinates.
    _BOARD_MARGIN = 10
    _BOARD_SIDE = 700
    # 98 px on the 700 px board is roughly 4 px after the resize to 28x28
    # done in infer().
    _PEN_WIDTH = 98

    def __init__(self, inferior):
        """Build the UI.

        Args:
            inferior: Object exposing ``infer(img_np)`` that returns
                ``(result, confidence)``.
        """
        super().__init__()
        self.inferior = inferior  # inference engine
        self.points = []  # points of the stroke in progress, board coordinates
        self.setupUi(self)

    def _make_label(self, geometry, name, family=None, bordered=False):
        """Create one of the large, bold, centred labels on the right side."""
        label = QtWidgets.QLabel(self.centralwidget)
        label.setGeometry(geometry)
        font = QtGui.QFont()
        if family is not None:
            font.setFamily(family)
        font.setPointSize(64)
        font.setBold(True)
        label.setFont(font)
        label.setAlignment(QtCore.Qt.AlignCenter)
        label.setObjectName(name)
        if bordered:
            label.setStyleSheet("border: 1px solid black;")
        return label

    def setupUi(self, mainwindow):
        """Create and lay out all child widgets of ``mainwindow``."""
        mainwindow.setObjectName("mainwindow")
        mainwindow.resize(1280, 720)
        self.centralwidget = QtWidgets.QWidget(mainwindow)
        self.centralwidget.setObjectName("centralwidget")

        # Drawing board: a fixed-size label whose pixmap is painted on.
        self.boarder = QLabel(self.centralwidget)
        self.boarder.setGeometry(
            QtCore.QRect(self._BOARD_MARGIN, self._BOARD_MARGIN, self._BOARD_SIDE, self._BOARD_SIDE)
        )
        self.boarder.setCursor(QtGui.QCursor(QtCore.Qt.CrossCursor))
        self.boarder.setObjectName("boarder")
        self.boarder.setStyleSheet("border: 1px solid black;")
        self.boarder.setFixedSize(self._BOARD_SIDE, self._BOARD_SIDE)  # pins min and max size
        # Render the empty label once and use that image as the canvas.
        self.boarder.setPixmap(self.boarder.grab())
        # NOTE(review): not used elsewhere in this file; it may refer to the
        # live canvas rather than a blank copy - confirm before relying on it.
        self.blank_board = self.boarder.pixmap()

        self.result = self._make_label(QtCore.QRect(720, 180, 550, 160), "result", bordered=True)
        self.text1 = self._make_label(
            QtCore.QRect(720, 10, 550, 160), "text1", family="Microsoft YaHei UI"
        )
        self.confidence = self._make_label(
            QtCore.QRect(720, 550, 550, 160), "confidence", bordered=True
        )
        self.text2 = self._make_label(
            QtCore.QRect(720, 370, 550, 160), "text2", family="Microsoft YaHei UI"
        )

        mainwindow.setCentralWidget(self.centralwidget)
        self.retranslateUi(mainwindow)
        QtCore.QMetaObject.connectSlotsByName(mainwindow)

    def retranslateUi(self, mainwindow):
        """Set the window title and the two caption labels."""
        _translate = QtCore.QCoreApplication.translate
        mainwindow.setWindowTitle(_translate("mainwindow", "手写数字识别演示程序"))
        self.text1.setText(_translate("mainwindow", "识别结果"))
        self.text2.setText(_translate("mainwindow", "识别置信度"))

    def _to_board(self, pos):
        """Convert a window-relative position to board (pixmap) coordinates."""
        return QtCore.QPoint(pos.x() - self._BOARD_MARGIN, pos.y() - self._BOARD_MARGIN)

    def paintEvent(self, event):
        """Draw the stroke in progress onto the board pixmap.

        Repaints are requested by the mouse handlers. This method must not
        call ``update()`` itself, otherwise the window repaints endlessly.
        """
        if len(self.points) < 2:
            return
        painter = QPainter(self.boarder.pixmap())
        pen = QPen()
        pen.setWidth(self._PEN_WIDTH)
        pen.setColor(Qt.black)
        painter.setPen(pen)
        for start, end in zip(self.points, self.points[1:]):
            painter.drawLine(start, end)
        painter.end()

    def mousePressEvent(self, event):
        """Left button starts a new stroke; right button runs recognition."""
        if event.button() == Qt.LeftButton:
            # Store the first point in board coordinates, like every later
            # point of the stroke.
            self.points = [self._to_board(event.pos())]
        elif event.button() == Qt.RightButton:
            result, confidence = self.infer()
            self.result.setText(f"{result}")
            self.confidence.setText(f"{confidence: .4f}")

    def mouseMoveEvent(self, event):
        """Extend the current stroke while the left button is held."""
        if not (event.buttons() & Qt.LeftButton):
            return  # ignore drags with other buttons
        self.points.append(self._to_board(event.pos()))
        self.update()

    def mouseDoubleClickEvent(self, event):
        """Left double click clears the board and both read-outs."""
        if event.button() != Qt.LeftButton:
            return
        # clear board
        painter = QPainter(self.boarder.pixmap())
        painter.eraseRect(self.boarder.rect())
        painter.end()
        self.points.clear()
        # clear read-outs
        self.result.setText("")
        self.confidence.setText("")
        self.update()

    def infer(self):
        """Convert the board to a 28x28 inverted grayscale image and classify it.

        Returns:
            Whatever ``self.inferior.infer`` returns; the caller in this
            class unpacks it as ``(result, confidence)``.
        """
        img_np = qimage2opencv(self.boarder.pixmap().toImage())
        img_np = cv2.resize(img_np, (28, 28))
        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
        # Invert so dark strokes on the board become high values.
        img_np = cv2.bitwise_not(img_np)
        return self.inferior.infer(img_np)
if __name__ == '__main__':
    # Command-line options are parsed before the model and the GUI are built.
    cli = argparse.ArgumentParser()
    cli.add_argument('--model', type=str, default="lenet", help='model')
    options = cli.parse_args()

    # Load the network, then hand it to the window and enter the Qt event loop.
    custom_inferior = Inferior(options)
    app = QApplication(sys.argv)
    window = MainWindow(custom_inferior)
    window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化