加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
main.py 10.82 KB
一键复制 编辑 原始数据 按行查看 历史
PandaKeyHub 提交于 2021-06-13 12:13 . 初次提交
# coding: utf-8
import sys
import os
import random
from PyQt5 import QtWidgets, QtCore
from PyQt5.QtWidgets import QApplication, QMainWindow, QFileDialog, QLineEdit, QMessageBox
from PyQt5.QtGui import QPixmap
import leafdisui
import torch
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np
import net.mobilenet as mobilenet
import net.unet as unet
alpha = 1
beta = 0.8 #二张图片的透明度
gamma = 0
color_list = [[255, 0, 0],
[255, 0, 255],
[0, 0, 255]]
Names = [
'健康',
'细菌性叶枯病',
'褐斑病',
'叶黑穗病',
]
class LeafDisCls():
"""构造函数"""
def __init__(self, ui):
self.ui = ui
self.image_pathname_list = [] #存储图像路径的列表
self.idx = 0 #当前图像索引
self.image_num = 0 #图像数量
self.disease_cls = -1 #病害类型
"""设置显卡或CPU"""
os.environ["CUDA_VISIBLE_DEVICES"] = '0' # 使用0号设备的GPU
# 若没有GPU 就是用cpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.device = device
"""加载网络模型"""
model_cls_file = './model/model_cls.pth' # 分类
model_seg_file = './model/model_seg.pth' #分割
# 使用深度卷积学习v1类别
net_cls = mobilenet.MobileNetV1()
net_seg = unet.UNetResNet18(4) # 4种类别
net_cls.to(device)
net_seg.to(device)
if model_cls_file.split('.')[-1] == 'pth':
# 从新加载分类模型
net_cls.load_state_dict(torch.load(model_cls_file))
else:
checkpoint = torch.load(model_cls_file)
net_cls.load_state_dict(checkpoint['model_state_dict'])
# 从新加载分割模型
if model_seg_file.split('.')[-1] == 'pth':
net_seg.load_state_dict(torch.load(model_seg_file))
else:
checkpoint = torch.load(model_seg_file)
net_seg.load_state_dict(checkpoint['model_state_dict'])
net_cls.eval() # 开始计算
net_seg.eval() # 分割开始计算
"""输入图像变换方式(组合)"""
transform=transforms.Compose([
transforms.Resize((64, 64)),# 变成64*64
transforms.ToTensor(), # 从H x W x C 变成C x H x W 范围在0-1之间
transforms.Normalize(mean=0.5,std=0.5)]) #使用均值和标准差对图像进行归一化
self.net_cls = net_cls
self.net_seg = net_seg
self.transform = transform
"""选择测试图像所在文件夹(通过“File”菜单下“Open”操作执行)"""
def open_dir(self):
image_dir = QFileDialog.getExistingDirectory(None,"选择文件夹","./")
image_name_list = os.listdir(image_dir)
for image_name in image_name_list:
image_pathname = image_dir + '/' + image_name
self.image_pathname_list.append(image_pathname)
# 排序图片的名字
random.shuffle(self.image_pathname_list)
# 图片数量
self.image_num = len(self.image_pathname_list)
"""开始测试(点击“运行”按钮后开始执行)"""
def run(self):
# 判断索引和图片数量, 数量为0退出
if self.idx >= self.image_num:
exit()
# 获取解决方法
with open('solution.txt')as f:
text = f.readlines()
# 获取图片
image_pathname = self.image_pathname_list[self.idx]
"""显示原图"""
# 按照宽高显示图片,返回重新进行绘图即可自适应窗口
pix = QPixmap(image_pathname).scaled(self.ui.label.width(), self.ui.label.height())
self.ui.label.setPixmap(pix) # 加载图片
# 反向传播时不进行求导, 也就是不进行梯度下降
with torch.no_grad():
print('[', self.idx, '/', self.image_num, ']', image_pathname)
"""读取原图并进行预处理"""
img_rgb = Image.open(image_pathname) # 打开图片,
img_rgb = img_rgb.convert("RGB") # 转换成rgb
# 颜色进行转换np.asarray(img_rgb)要转换的, cv2.COLOR_BGR2RGB转换的格式, 从BGR转换成RGB
img_bgr = cv2.cvtColor(np.asarray(img_rgb), cv2.COLOR_BGR2RGB)
img_transformed = self.transform(img_rgb) # 转换图片
inputs = img_transformed.unsqueeze(0) # 对图像添加一个维度,
inputs = inputs.to(self.device) # 设置cpu或者GPU
"""进行分类"""
output_cls = self.net_cls(inputs)
# 返回1维度的最大值所在的索引, 病害类别
disease_cls = torch.max(output_cls, 1)
disease_cls = disease_cls.indices.item() # 获取具体值
self.disease_cls = disease_cls # 病害类别,比如: 0,2,4
"""进行分割"""
output_seg = self.net_seg(inputs) #使用cpu或者gpu 进行分割
point_cls = torch.max(output_seg, 1).indices # 返回1维度最大值并进行分割
# 在0维度压缩数据,取出1维度, 并且转换为unit8格式
point_cls = point_cls.squeeze(0).cpu().numpy().astype(np.uint8)
point_cls = Image.fromarray(point_cls) # 在cpu缓冲区进行计算
point_cls = point_cls.convert('L') # 根据指定格式进行转换, 并返回转换后的图片
point_cls = np.asarray(point_cls) # 作为数组处理
# 调整图片大小
point_cls = cv2.resize(point_cls, (img_rgb.width, img_rgb.height))
mask1 = point_cls == 1
mask2 = point_cls == 2
mask3 = point_cls == 3
mask1 = mask1.astype(np.uint8)
mask2 = mask2.astype(np.uint8)
mask3 = mask3.astype(np.uint8)
pos_num = np.sum(mask1) + np.sum(mask2) + np.sum(mask3)
mask_rgb1 = cv2.cvtColor(mask1, cv2.COLOR_GRAY2BGR) # 将mask1从灰度图转换成BGR
mask_rgb2 = cv2.cvtColor(mask2, cv2.COLOR_GRAY2BGR)
mask_rgb3 = cv2.cvtColor(mask3, cv2.COLOR_GRAY2BGR)
mask = mask_rgb1 * color_list[0] # 转化成红色
mask += mask_rgb2 * color_list[1] # 转绿色
mask += mask_rgb3 * color_list[2] # 转换蓝色
mask = mask.astype(np.uint8)
img_seg = cv2.addWeighted(img_bgr, alpha, mask, beta, gamma) #将分割结果叠加到原图上,让原来图片特征更明显
cv2.imwrite('./temp.jpg', img_seg) # 写到文件
"""显示疾病面积占比"""
area_ratio = round(100*pos_num/(img_rgb.width * img_rgb.height) ,2) # 求百分比
show_text = str(area_ratio) + '%'
self.ui.lineEdit_2.setText(show_text) # 将结果设置在窗口
"""显示分类和分割结果"""
show_text = Names[disease_cls]
self.ui.lineEdit.setText(show_text) # 病变类别
show_text = text[disease_cls]
self.ui.textEdit.setText(show_text)# 病变的解决方法
# 显示图片
pix = QPixmap('./temp.jpg').scaled(self.ui.label_5.width(), self.ui.label_5.height())
self.ui.label_5.setPixmap(pix)
self.idx += 1
def save(self):
with open('solution.txt') as f:
text = f.readlines()
with open('solution.txt', 'w') as f:
# 创建空字符串
info = self.ui.textEdit.toPlainText()
# 添加换行符
if info[-1] != '\n':
info = info + '\n'
# 病害类型, info 是处理方法
text[self.disease_cls] = info
lines = ''.join(text) # 写到文件中
f.write(lines)
class UserWindow():
def __init__(self):
"""读取用户信息"""
with open('user.txt') as f:
lines = f.readlines()
if len(lines) > 0:
lines = ''.join(lines)
self.name_passward_dict = eval(lines)
else:
self.name_passward_dict = dict()
self.name_passward_dict['root'] = 'root'
"""主窗体"""
self.MainWindow = QMainWindow()
self.ui = leafdisui.Ui_MainWindow()
self.ui.setupUi(self.MainWindow)
# 显示窗口
self.leafdiscls = LeafDisCls(self.ui)
# 打开图片目录
self.ui.actionOpen.triggered.connect(self.leafdiscls.open_dir)
# 运行
self.ui.pushButton.clicked.connect(self.leafdiscls.run)
# 保存结果
self.ui.pushButton_2.clicked.connect(self.leafdiscls.save)
"""登录窗体"""
self.MainWindow_SignIn = QMainWindow()
self.MainWindow_SignIn.setWindowTitle('登录')
self.MainWindow_SignIn.resize(400, 300)
self.pushButton_signin1 = QtWidgets.QPushButton(self.MainWindow_SignIn)
self.pushButton_signin1.setGeometry(QtCore.QRect(110, 200, 80, 30))
self.pushButton_signin1.setText("登录")
self.pushButton_signin2 = QtWidgets.QPushButton(self.MainWindow_SignIn)
self.pushButton_signin2.setGeometry(QtCore.QRect(200, 200, 80, 30))
self.pushButton_signin2.setText("注册")
self.label1_signin = QtWidgets.QLabel(self.MainWindow_SignIn)
self.label1_signin.setGeometry(QtCore.QRect(80, 50, 80, 30))
self.label1_signin.setText("用户名:")
self.line_edit1 = QtWidgets.QLineEdit(self.MainWindow_SignIn)
self.line_edit1.setGeometry(QtCore.QRect(140, 52, 120, 25))
self.label1_signin = QtWidgets.QLabel(self.MainWindow_SignIn)
self.label1_signin.setGeometry(QtCore.QRect(80, 100, 80, 30))
self.label1_signin.setText("密 码:")
self.line_edit2 = QtWidgets.QLineEdit(self.MainWindow_SignIn)
self.line_edit2.setGeometry(QtCore.QRect(140, 102, 120, 25))
self.line_edit2.setEchoMode(QLineEdit.Password)
# 登录
self.pushButton_signin1.clicked.connect(self.sin_in)
# 注册
self.pushButton_signin2.clicked.connect(self.register)
self.MainWindow_SignIn.show()
"""用户登录"""
def sin_in(self):
# 输入框不为空
if self.line_edit1.text() != '' and self.line_edit2.text() != '':
username = self.line_edit1.text()# 获取用户名和密码
passward = self.line_edit2.text()
# 判断用户户名是否存在
if username in self.name_passward_dict.keys():
# 用户名和密码一样
if self.name_passward_dict[username] == passward:
self.MainWindow_SignIn.close() #关闭登录窗口
if username != 'root':# 若用户不是root,就只读
self.ui.textEdit.setReadOnly(True)
self.ui.pushButton_2.setEnabled(False) # 不能修改
self.MainWindow.show()
else:
msg_box = QMessageBox(QMessageBox.Warning, '错误', '密码错误!')
msg_box.exec_()
self.line_edit1.setText('')
self.line_edit2.setText('')
else:
# 用户名不存在
msg_box = QMessageBox(QMessageBox.Warning, '错误', '用户不存在,请先注册!')
msg_box.exec_()
self.line_edit1.setText('')
self.line_edit2.setText('')
else:
msg_box = QMessageBox(QMessageBox.Warning, '错误', '用户名或密码不能为空!')
msg_box.exec_()
self.line_edit1.setText('')
self.line_edit2.setText('')
"""用户注册"""
def register(self):
if self.line_edit1.text() != '' and self.line_edit2.text() != '':
username = self.line_edit1.text()
passward = self.line_edit2.text()
if username not in self.name_passward_dict.keys():
self.name_passward_dict[username] = passward
with open('user.txt', 'w') as f:
f.write(str(self.name_passward_dict))
msg_box = QMessageBox(QMessageBox.Warning, '提示', '注册成功!')
msg_box.exec_()
self.line_edit1.setText('')
self.line_edit2.setText('')
else:
msg_box = QMessageBox(QMessageBox.Warning, '错误', '用户已存在!')
msg_box.exec_()
self.line_edit1.setText('')
self.line_edit2.setText('')
else:
msg_box = QMessageBox(QMessageBox.Warning, '错误', '用户名或密码不能为空!')
msg_box.exec_()
self.line_edit1.setText('')
self.line_edit2.setText('')
if __name__ == '__main__':
app = QApplication(sys.argv)
MyWindow = UserWindow()
sys.exit(app.exec_())
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化