代码拉取完成,页面将自动刷新
同步操作将从 PandaKeyHub/rice_leaf_desease_test 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
# coding: utf-8
import sys
import os
import random
from PyQt5 import QtWidgets, QtCore
from PyQt5.QtWidgets import QApplication, QMainWindow, QFileDialog, QLineEdit, QMessageBox
from PyQt5.QtGui import QPixmap
import leafdisui
import torch
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np
import net.mobilenet as mobilenet
import net.unet as unet
alpha = 1
beta = 0.8 #二张图片的透明度
gamma = 0
color_list = [[255, 0, 0],
[255, 0, 255],
[0, 0, 255]]
Names = [
'健康',
'细菌性叶枯病',
'褐斑病',
'叶黑穗病',
]
class LeafDisCls():
"""构造函数"""
def __init__(self, ui):
self.ui = ui
self.image_pathname_list = [] #存储图像路径的列表
self.idx = 0 #当前图像索引
self.image_num = 0 #图像数量
self.disease_cls = -1 #病害类型
"""设置显卡或CPU"""
os.environ["CUDA_VISIBLE_DEVICES"] = '0' # 使用0号设备的GPU
# 若没有GPU 就是用cpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.device = device
"""加载网络模型"""
model_cls_file = './model/model_cls.pth' # 分类
model_seg_file = './model/model_seg.pth' #分割
# 使用深度卷积学习v1类别
net_cls = mobilenet.MobileNetV1()
net_seg = unet.UNetResNet18(4) # 4种类别
net_cls.to(device)
net_seg.to(device)
if model_cls_file.split('.')[-1] == 'pth':
# 从新加载分类模型
net_cls.load_state_dict(torch.load(model_cls_file))
else:
checkpoint = torch.load(model_cls_file)
net_cls.load_state_dict(checkpoint['model_state_dict'])
# 从新加载分割模型
if model_seg_file.split('.')[-1] == 'pth':
net_seg.load_state_dict(torch.load(model_seg_file))
else:
checkpoint = torch.load(model_seg_file)
net_seg.load_state_dict(checkpoint['model_state_dict'])
net_cls.eval() # 开始计算
net_seg.eval() # 分割开始计算
"""输入图像变换方式(组合)"""
transform=transforms.Compose([
transforms.Resize((64, 64)),# 变成64*64
transforms.ToTensor(), # 从H x W x C 变成C x H x W 范围在0-1之间
transforms.Normalize(mean=0.5,std=0.5)]) #使用均值和标准差对图像进行归一化
self.net_cls = net_cls
self.net_seg = net_seg
self.transform = transform
"""选择测试图像所在文件夹(通过“File”菜单下“Open”操作执行)"""
def open_dir(self):
image_dir = QFileDialog.getExistingDirectory(None,"选择文件夹","./")
image_name_list = os.listdir(image_dir)
for image_name in image_name_list:
image_pathname = image_dir + '/' + image_name
self.image_pathname_list.append(image_pathname)
# 排序图片的名字
random.shuffle(self.image_pathname_list)
# 图片数量
self.image_num = len(self.image_pathname_list)
"""开始测试(点击“运行”按钮后开始执行)"""
def run(self):
# 判断索引和图片数量, 数量为0退出
if self.idx >= self.image_num:
exit()
# 获取解决方法
with open('solution.txt')as f:
text = f.readlines()
# 获取图片
image_pathname = self.image_pathname_list[self.idx]
"""显示原图"""
# 按照宽高显示图片,返回重新进行绘图即可自适应窗口
pix = QPixmap(image_pathname).scaled(self.ui.label.width(), self.ui.label.height())
self.ui.label.setPixmap(pix) # 加载图片
# 反向传播时不进行求导, 也就是不进行梯度下降
with torch.no_grad():
print('[', self.idx, '/', self.image_num, ']', image_pathname)
"""读取原图并进行预处理"""
img_rgb = Image.open(image_pathname) # 打开图片,
img_rgb = img_rgb.convert("RGB") # 转换成rgb
# 颜色进行转换np.asarray(img_rgb)要转换的, cv2.COLOR_BGR2RGB转换的格式, 从BGR转换成RGB
img_bgr = cv2.cvtColor(np.asarray(img_rgb), cv2.COLOR_BGR2RGB)
img_transformed = self.transform(img_rgb) # 转换图片
inputs = img_transformed.unsqueeze(0) # 对图像添加一个维度,
inputs = inputs.to(self.device) # 设置cpu或者GPU
"""进行分类"""
output_cls = self.net_cls(inputs)
# 返回1维度的最大值所在的索引, 病害类别
disease_cls = torch.max(output_cls, 1)
disease_cls = disease_cls.indices.item() # 获取具体值
self.disease_cls = disease_cls # 病害类别,比如: 0,2,4
"""进行分割"""
output_seg = self.net_seg(inputs) #使用cpu或者gpu 进行分割
point_cls = torch.max(output_seg, 1).indices # 返回1维度最大值并进行分割
# 在0维度压缩数据,取出1维度, 并且转换为unit8格式
point_cls = point_cls.squeeze(0).cpu().numpy().astype(np.uint8)
point_cls = Image.fromarray(point_cls) # 在cpu缓冲区进行计算
point_cls = point_cls.convert('L') # 根据指定格式进行转换, 并返回转换后的图片
point_cls = np.asarray(point_cls) # 作为数组处理
# 调整图片大小
point_cls = cv2.resize(point_cls, (img_rgb.width, img_rgb.height))
mask1 = point_cls == 1
mask2 = point_cls == 2
mask3 = point_cls == 3
mask1 = mask1.astype(np.uint8)
mask2 = mask2.astype(np.uint8)
mask3 = mask3.astype(np.uint8)
pos_num = np.sum(mask1) + np.sum(mask2) + np.sum(mask3)
mask_rgb1 = cv2.cvtColor(mask1, cv2.COLOR_GRAY2BGR) # 将mask1从灰度图转换成BGR
mask_rgb2 = cv2.cvtColor(mask2, cv2.COLOR_GRAY2BGR)
mask_rgb3 = cv2.cvtColor(mask3, cv2.COLOR_GRAY2BGR)
mask = mask_rgb1 * color_list[0] # 转化成红色
mask += mask_rgb2 * color_list[1] # 转绿色
mask += mask_rgb3 * color_list[2] # 转换蓝色
mask = mask.astype(np.uint8)
img_seg = cv2.addWeighted(img_bgr, alpha, mask, beta, gamma) #将分割结果叠加到原图上,让原来图片特征更明显
cv2.imwrite('./temp.jpg', img_seg) # 写到文件
"""显示疾病面积占比"""
area_ratio = round(100*pos_num/(img_rgb.width * img_rgb.height) ,2) # 求百分比
show_text = str(area_ratio) + '%'
self.ui.lineEdit_2.setText(show_text) # 将结果设置在窗口
"""显示分类和分割结果"""
show_text = Names[disease_cls]
self.ui.lineEdit.setText(show_text) # 病变类别
show_text = text[disease_cls]
self.ui.textEdit.setText(show_text)# 病变的解决方法
# 显示图片
pix = QPixmap('./temp.jpg').scaled(self.ui.label_5.width(), self.ui.label_5.height())
self.ui.label_5.setPixmap(pix)
self.idx += 1
def save(self):
with open('solution.txt') as f:
text = f.readlines()
with open('solution.txt', 'w') as f:
# 创建空字符串
info = self.ui.textEdit.toPlainText()
# 添加换行符
if info[-1] != '\n':
info = info + '\n'
# 病害类型, info 是处理方法
text[self.disease_cls] = info
lines = ''.join(text) # 写到文件中
f.write(lines)
class UserWindow():
def __init__(self):
"""读取用户信息"""
with open('user.txt') as f:
lines = f.readlines()
if len(lines) > 0:
lines = ''.join(lines)
self.name_passward_dict = eval(lines)
else:
self.name_passward_dict = dict()
self.name_passward_dict['root'] = 'root'
"""主窗体"""
self.MainWindow = QMainWindow()
self.ui = leafdisui.Ui_MainWindow()
self.ui.setupUi(self.MainWindow)
# 显示窗口
self.leafdiscls = LeafDisCls(self.ui)
# 打开图片目录
self.ui.actionOpen.triggered.connect(self.leafdiscls.open_dir)
# 运行
self.ui.pushButton.clicked.connect(self.leafdiscls.run)
# 保存结果
self.ui.pushButton_2.clicked.connect(self.leafdiscls.save)
"""登录窗体"""
self.MainWindow_SignIn = QMainWindow()
self.MainWindow_SignIn.setWindowTitle('登录')
self.MainWindow_SignIn.resize(400, 300)
self.pushButton_signin1 = QtWidgets.QPushButton(self.MainWindow_SignIn)
self.pushButton_signin1.setGeometry(QtCore.QRect(110, 200, 80, 30))
self.pushButton_signin1.setText("登录")
self.pushButton_signin2 = QtWidgets.QPushButton(self.MainWindow_SignIn)
self.pushButton_signin2.setGeometry(QtCore.QRect(200, 200, 80, 30))
self.pushButton_signin2.setText("注册")
self.label1_signin = QtWidgets.QLabel(self.MainWindow_SignIn)
self.label1_signin.setGeometry(QtCore.QRect(80, 50, 80, 30))
self.label1_signin.setText("用户名:")
self.line_edit1 = QtWidgets.QLineEdit(self.MainWindow_SignIn)
self.line_edit1.setGeometry(QtCore.QRect(140, 52, 120, 25))
self.label1_signin = QtWidgets.QLabel(self.MainWindow_SignIn)
self.label1_signin.setGeometry(QtCore.QRect(80, 100, 80, 30))
self.label1_signin.setText("密 码:")
self.line_edit2 = QtWidgets.QLineEdit(self.MainWindow_SignIn)
self.line_edit2.setGeometry(QtCore.QRect(140, 102, 120, 25))
self.line_edit2.setEchoMode(QLineEdit.Password)
# 登录
self.pushButton_signin1.clicked.connect(self.sin_in)
# 注册
self.pushButton_signin2.clicked.connect(self.register)
self.MainWindow_SignIn.show()
"""用户登录"""
def sin_in(self):
# 输入框不为空
if self.line_edit1.text() != '' and self.line_edit2.text() != '':
username = self.line_edit1.text()# 获取用户名和密码
passward = self.line_edit2.text()
# 判断用户户名是否存在
if username in self.name_passward_dict.keys():
# 用户名和密码一样
if self.name_passward_dict[username] == passward:
self.MainWindow_SignIn.close() #关闭登录窗口
if username != 'root':# 若用户不是root,就只读
self.ui.textEdit.setReadOnly(True)
self.ui.pushButton_2.setEnabled(False) # 不能修改
self.MainWindow.show()
else:
msg_box = QMessageBox(QMessageBox.Warning, '错误', '密码错误!')
msg_box.exec_()
self.line_edit1.setText('')
self.line_edit2.setText('')
else:
# 用户名不存在
msg_box = QMessageBox(QMessageBox.Warning, '错误', '用户不存在,请先注册!')
msg_box.exec_()
self.line_edit1.setText('')
self.line_edit2.setText('')
else:
msg_box = QMessageBox(QMessageBox.Warning, '错误', '用户名或密码不能为空!')
msg_box.exec_()
self.line_edit1.setText('')
self.line_edit2.setText('')
"""用户注册"""
def register(self):
if self.line_edit1.text() != '' and self.line_edit2.text() != '':
username = self.line_edit1.text()
passward = self.line_edit2.text()
if username not in self.name_passward_dict.keys():
self.name_passward_dict[username] = passward
with open('user.txt', 'w') as f:
f.write(str(self.name_passward_dict))
msg_box = QMessageBox(QMessageBox.Warning, '提示', '注册成功!')
msg_box.exec_()
self.line_edit1.setText('')
self.line_edit2.setText('')
else:
msg_box = QMessageBox(QMessageBox.Warning, '错误', '用户已存在!')
msg_box.exec_()
self.line_edit1.setText('')
self.line_edit2.setText('')
else:
msg_box = QMessageBox(QMessageBox.Warning, '错误', '用户名或密码不能为空!')
msg_box.exec_()
self.line_edit1.setText('')
self.line_edit2.setText('')
if __name__ == '__main__':
app = QApplication(sys.argv)
MyWindow = UserWindow()
sys.exit(app.exec_())
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。