加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
app.py 58.18 KB
一键复制 编辑 原始数据 按行查看 历史
kerlomz 提交于 2020-11-15 21:49 . 兼容tf2+
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558
import os
import re
import math
import sys
import shutil
import json
import traceback
import PIL.Image as PilImage
import threading
import tkinter as tk
from tkinter import messagebox
from tkinter import ttk
from tkinter import filedialog
from constants import *
from config import ModelConfig, OUTPUT_SHAPE1_MAP, NETWORK_MAP, DataAugmentationEntity, PretreatmentEntity, get_version
from make_dataset import DataSets
from predict_testing import Predict
from trains import Trains
from category import category_extract, SIMPLE_CATEGORY_MODEL
from utils.category_frequency_statistics import fetch_category_list
from gui.utils import LayoutGUI
from gui.data_augmentation import DataAugmentationDialog
from gui.pretreatment import PretreatmentDialog
# Message text for settings that (presumably) cannot be edited from the GUI
# and have to be changed in the configuration file instead.
# NOTE(review): not referenced in the visible part of this module — confirm usage.
NOT_EDITABLE_MSG = "ONLY SUPPORT MODIFICATION FROM FILE"
class Wizard:
# Background job thread slot; only annotated here, no default is assigned.
job: threading.Thread
# Trains instance for the current task; None until one is assigned.
current_task: Trains = None
# While True, attach/delete/reset/clear actions refuse to run and show an error.
is_task_running: bool = False
# Default augmentation / pretreatment settings.
# NOTE(review): these are class attributes, so the initial instances are shared
# by every Wizard until a method rebinds them on self (project_name_fill_callback
# does) — confirm that sharing is intended.
data_augmentation_entity = DataAugmentationEntity()
pretreatment_entity = PretreatmentEntity()
# Label extraction regex; the default lazily matches the file-name text before
# the first underscore.
extract_regex = ".*?(?=_)"
# Label separator; empty by default. NOTE(review): exact use is not visible here.
label_split = ""
# Model configuration slot (None by default); not assigned in the visible code.
model_conf: ModelConfig = None
def __init__(self, parent: tk.Tk):
    """Build the complete trainer window on *parent*.

    Sets the window icon and title, makes sure the ``./projects`` root
    directory exists, creates the menu bar and every widget group (sample
    source, neural network, training configuration, project configuration,
    sample dataset, action buttons), and finally applies a fixed 815x700
    geometry centred on the screen. Most widgets are placed through
    ``LayoutGUI`` helpers relative to a previously created target widget.

    Args:
        parent: the Tk root window the wizard draws into.
    """
    # Layout constants handed to LayoutGUI: where the first group starts, the
    # normal spacing, and (presumably) the spacing used when a helper is called
    # with tiny_space=True.
    self.layout = {
        'global': {
            'start': {'x': 15, 'y': 20},
            'space': {'x': 15, 'y': 25},
            'tiny_space': {'x': 5, 'y': 10}
        }
    }
    self.parent = parent
    self.parent.iconbitmap(Wizard.resource_path("resource/icon.ico"))
    self.current_project: str = ""
    # All projects live under this directory; create it on first launch.
    self.project_root_path = "./projects"
    if not os.path.exists(self.project_root_path):
        os.makedirs(self.project_root_path)
    self.parent.title('Eve-DL Trainer v1({})'.format(get_version()))
    self.parent.resizable(width=False, height=False)
    self.window_width = 815
    self.window_height = 700
    self.layout_utils = LayoutGUI(self.layout, self.window_width)
    # "WxH+X+Y" geometry that centres the window; "%d" truncates the float
    # offsets produced by "/ 2". Applied at the very end of __init__.
    screenwidth = self.parent.winfo_screenwidth()
    screenheight = self.parent.winfo_screenheight()
    size = '%dx%d+%d+%d' % (
        self.window_width,
        self.window_height,
        (screenwidth - self.window_width) / 2,
        (screenheight - self.window_height) / 2
    )
    # Any left click inside the window may commit a newly typed project name
    # (see blank_click).
    self.parent.bind('<Button-1>', lambda x: self.blank_click(x))
    # ============================= Menu bar =====================================
    self.menubar = tk.Menu(self.parent)
    self.data_menu = tk.Menu(self.menubar, tearoff=False)
    self.help_menu = tk.Menu(self.menubar, tearoff=False)
    self.system_menu = tk.Menu(self.menubar, tearoff=False)
    # Memory-usage fraction. Its menu is commented out below, but read_conf
    # still stores model_conf.memory_usage in it.
    self.edit_var = tk.DoubleVar()
    # Label source chosen under Data > Label From: 'FileName' or 'TXT'.
    self.label_from_var = tk.StringVar()
    # Memory-usage menu, currently disabled:
    # self.memory_usage_menu = tk.Menu(self.menubar, tearoff=False)
    # self.memory_usage_menu.add_radiobutton(label="50%", variable=self.edit_var, value=0.5)
    # self.memory_usage_menu.add_radiobutton(label="60%", variable=self.edit_var, value=0.6)
    # self.memory_usage_menu.add_radiobutton(label="70%", variable=self.edit_var, value=0.7)
    # self.memory_usage_menu.add_radiobutton(label="80%", variable=self.edit_var, value=0.8)
    self.label_from_menu = tk.Menu(self.menubar, tearoff=False)
    self.label_from_menu.add_radiobutton(label="FileName", variable=self.label_from_var, value='FileName')
    self.label_from_menu.add_radiobutton(label="TXT", variable=self.label_from_var, value='TXT')
    # Nothing is added to the System menu here (its only entry is commented out).
    self.menubar.add_cascade(label="System", menu=self.system_menu)
    # self.system_menu.add_cascade(label="Memory Usage", menu=self.memory_usage_menu)
    self.data_menu.add_command(label="Data Augmentation", command=lambda: self.popup_data_augmentation())
    self.data_menu.add_command(label="Pretreatment", command=lambda: self.popup_pretreatment())
    self.data_menu.add_separator()
    self.data_menu.add_command(label="Clear Dataset", command=lambda: self.clear_dataset())
    self.data_menu.add_separator()
    self.data_menu.add_cascade(label="Label From", menu=self.label_from_menu)
    self.data_menu.add_command(label="Fetch Category", command=lambda: self.fetch_category())
    self.menubar.add_cascade(label="Data", menu=self.data_menu)
    self.help_menu.add_command(label="About", command=lambda: self.popup_about())
    self.menubar.add_cascade(label="Help", menu=self.help_menu)
    self.parent.config(menu=self.menubar)
    # ============================= Group 1: Sample Source =======================
    self.label_frame_source = ttk.Labelframe(self.parent, text='Sample Source')
    self.label_frame_source.place(
        x=self.layout['global']['start']['x'],
        y=self.layout['global']['start']['y'],
        width=790,
        height=150
    )
    # Training source path - label
    self.dataset_train_path_text = ttk.Label(self.parent, text='Training Path', anchor=tk.W)
    self.layout_utils.inside_widget(
        src=self.dataset_train_path_text,
        target=self.label_frame_source,
        width=90,
        height=20
    )
    # Training source path - listbox (the <Delete> key removes the selected entry)
    self.source_train_path_listbox = tk.Listbox(self.parent, font=('微软雅黑', 9))
    self.layout_utils.next_to_widget(
        src=self.source_train_path_listbox,
        target=self.dataset_train_path_text,
        width=600,
        height=50,
        tiny_space=True
    )
    self.source_train_path_listbox.bind(
        sequence="<Delete>",
        func=lambda x: self.listbox_delete_item_callback(x, self.source_train_path_listbox)
    )
    self.listbox_scrollbar(self.source_train_path_listbox)
    # Training source path - browse button
    self.btn_browse_train = ttk.Button(
        self.parent, text='Browse', command=lambda: self.browse_dataset(DatasetType.Directory, RunMode.Trains)
    )
    self.layout_utils.next_to_widget(
        src=self.btn_browse_train,
        target=self.source_train_path_listbox,
        width=60,
        height=24,
        tiny_space=True
    )
    # Validation source path - label, placed manually below the training listbox
    # and aligned with the training label.
    label_edge = self.layout_utils.object_edge_info(self.dataset_train_path_text)
    widget_edge = self.layout_utils.object_edge_info(self.source_train_path_listbox)
    self.dataset_validation_path_text = ttk.Label(self.parent, text='Validation Path', anchor=tk.W)
    self.dataset_validation_path_text.place(
        x=label_edge['x'],
        y=widget_edge['edge_y'] + self.layout['global']['space']['y'] / 2,
        width=90,
        height=20
    )
    # Validation source path - listbox
    self.source_validation_path_listbox = tk.Listbox(self.parent, font=('微软雅黑', 9))
    self.layout_utils.next_to_widget(
        src=self.source_validation_path_listbox,
        target=self.dataset_validation_path_text,
        width=600,
        height=50,
        tiny_space=True
    )
    self.source_validation_path_listbox.bind(
        sequence="<Delete>",
        func=lambda x: self.listbox_delete_item_callback(x, self.source_validation_path_listbox)
    )
    self.listbox_scrollbar(self.source_validation_path_listbox)
    # Validation source path - browse button
    self.btn_browse_validation = ttk.Button(
        self.parent, text='Browse', command=lambda: self.browse_dataset(DatasetType.Directory, RunMode.Validation)
    )
    self.layout_utils.next_to_widget(
        src=self.btn_browse_validation,
        target=self.source_validation_path_listbox,
        width=60,
        height=24,
        tiny_space=True
    )
    # ============================= Group 2: Neural Network ======================
    self.label_frame_neu = ttk.Labelframe(self.parent, text='Neural Network')
    self.layout_utils.below_widget(
        src=self.label_frame_neu,
        target=self.label_frame_source,
        width=790,
        height=120,
        tiny_space=False
    )
    # Maximum label count - label
    self.label_num_text = ttk.Label(self.parent, text='Label Num', anchor=tk.W)
    self.layout_utils.inside_widget(
        src=self.label_num_text,
        target=self.label_frame_neu,
        width=65,
        height=20,
    )
    # Maximum label count - spinbox (1..12, default 1)
    self.label_num_spin = ttk.Spinbox(self.parent, from_=1, to=12)
    self.label_num_spin.set(1)
    self.layout_utils.next_to_widget(
        src=self.label_num_spin,
        target=self.label_num_text,
        width=50,
        height=20,
        tiny_space=True
    )
    # Image channels - label
    self.channel_text = ttk.Label(self.parent, text='Channel', anchor=tk.W)
    self.layout_utils.next_to_widget(
        src=self.channel_text,
        target=self.label_num_spin,
        width=50,
        height=20,
        tiny_space=False
    )
    # Image channels - combobox (3 or 1; index 1 selects 1 channel by default)
    self.comb_channel = ttk.Combobox(self.parent, values=(3, 1), state='readonly')
    self.comb_channel.current(1)
    self.layout_utils.next_to_widget(
        src=self.comb_channel,
        target=self.channel_text,
        width=38,
        height=20,
        tiny_space=True
    )
    # CNN layer - label
    self.neu_cnn_text = ttk.Label(self.parent, text='CNN Layer', anchor=tk.W)
    self.layout_utils.next_to_widget(
        src=self.neu_cnn_text,
        target=self.comb_channel,
        width=65,
        height=20,
        tiny_space=False
    )
    # CNN layer - combobox of CNNNetwork member names (first one selected)
    self.comb_neu_cnn = ttk.Combobox(self.parent, values=[_.name for _ in CNNNetwork], state='readonly')
    self.comb_neu_cnn.current(0)
    self.layout_utils.next_to_widget(
        src=self.comb_neu_cnn,
        target=self.neu_cnn_text,
        width=80,
        height=20,
        tiny_space=True
    )
    # Recurrent layer - label
    self.neu_recurrent_text = ttk.Label(self.parent, text='Recurrent Layer', anchor=tk.W)
    self.layout_utils.next_to_widget(
        src=self.neu_recurrent_text,
        target=self.comb_neu_cnn,
        width=95,
        height=20,
        tiny_space=False
    )
    # Recurrent layer - combobox of RecurrentNetwork member names (second one
    # selected); choosing 'NoRecurrent' switches the loss via auto_loss.
    self.comb_recurrent = ttk.Combobox(self.parent, values=[_.name for _ in RecurrentNetwork], state='readonly')
    self.comb_recurrent.current(1)
    self.layout_utils.next_to_widget(
        src=self.comb_recurrent,
        target=self.neu_recurrent_text,
        width=112,
        height=20,
        tiny_space=True
    )
    self.comb_recurrent.bind("<<ComboboxSelected>>", lambda x: self.auto_loss(x))
    # Recurrent units - label
    self.units_num_text = ttk.Label(self.parent, text='UnitsNum', anchor=tk.W)
    self.layout_utils.next_to_widget(
        src=self.units_num_text,
        target=self.comb_recurrent,
        width=60,
        height=20,
        tiny_space=False
    )
    # Recurrent units - spinbox (16..512 in steps of 16, wraps around, default 64)
    self.units_num_spin = ttk.Spinbox(self.parent, from_=16, to=512, increment=16, wrap=True)
    self.units_num_spin.set(64)
    self.layout_utils.next_to_widget(
        src=self.units_num_spin,
        target=self.units_num_text,
        width=55,
        height=20,
        tiny_space=True
    )
    # Loss function - label (second row of this group)
    self.loss_func_text = ttk.Label(self.parent, text='Loss Function', anchor=tk.W)
    self.layout_utils.below_widget(
        src=self.loss_func_text,
        target=self.label_num_text,
        width=85,
        height=20,
        tiny_space=True
    )
    # Loss function - combobox of LossFunction member names (second one selected)
    self.comb_loss = ttk.Combobox(self.parent, values=[_.name for _ in LossFunction], state='readonly')
    self.comb_loss.current(1)
    self.layout_utils.next_to_widget(
        src=self.comb_loss,
        target=self.loss_func_text,
        width=101,
        height=20,
        tiny_space=True
    )
    # Optimizer - label
    self.optimizer_text = ttk.Label(self.parent, text='Optimizer', anchor=tk.W)
    self.layout_utils.next_to_widget(
        src=self.optimizer_text,
        target=self.comb_loss,
        width=60,
        height=20,
        tiny_space=False
    )
    # Optimizer - combobox of Optimizer member names (first one selected)
    self.comb_optimizer = ttk.Combobox(self.parent, values=[_.name for _ in Optimizer], state='readonly')
    self.comb_optimizer.current(0)
    self.layout_utils.next_to_widget(
        src=self.comb_optimizer,
        target=self.optimizer_text,
        width=88,
        height=20,
        tiny_space=True
    )
    # Learning rate - label
    self.learning_rate_text = ttk.Label(self.parent, text='Learning Rate', anchor=tk.W)
    self.layout_utils.next_to_widget(
        src=self.learning_rate_text,
        target=self.comb_optimizer,
        width=85,
        height=20,
        tiny_space=False
    )
    # Learning rate - spinbox (0.00001..0.1, step 0.0001, default 0.001)
    self.learning_rate_spin = ttk.Spinbox(self.parent, from_=0.00001, to=0.1, increment='0.0001')
    self.learning_rate_spin.set(0.001)
    self.layout_utils.next_to_widget(
        src=self.learning_rate_spin,
        target=self.learning_rate_text,
        width=67,
        height=20,
        tiny_space=True
    )
    # Resize - label
    self.resize_text = ttk.Label(self.parent, text='Resize', anchor=tk.W)
    self.layout_utils.next_to_widget(
        src=self.resize_text,
        target=self.learning_rate_spin,
        width=36,
        height=20,
        tiny_space=False
    )
    # Resize - entry holding a JSON-style "[w, h]" list (default [150, 50])
    self.resize_val = tk.StringVar()
    self.resize_val.set('[150, 50]')
    self.resize_entry = ttk.Entry(self.parent, textvariable=self.resize_val, justify=tk.LEFT)
    self.layout_utils.next_to_widget(
        src=self.resize_entry,
        target=self.resize_text,
        width=60,
        height=20,
        tiny_space=True
    )
    # Size - label
    self.size_text = ttk.Label(self.parent, text='Size', anchor=tk.W)
    self.layout_utils.next_to_widget(
        src=self.size_text,
        target=self.resize_entry,
        width=30,
        height=20,
        tiny_space=False
    )
    # Size - entry holding "[width, height]" (default [-1, -1]; read_conf fills
    # it from the model config's image_width / image_height)
    self.size_val = tk.StringVar()
    self.size_val.set('[-1, -1]')
    self.size_entry = ttk.Entry(self.parent, textvariable=self.size_val, justify=tk.LEFT)
    self.layout_utils.next_to_widget(
        src=self.size_entry,
        target=self.size_text,
        width=60,
        height=20,
        tiny_space=True
    )
    # Category - label (third row of this group)
    self.category_text = ttk.Label(self.parent, text='Category', anchor=tk.W)
    self.layout_utils.below_widget(
        src=self.category_text,
        target=self.loss_func_text,
        width=72,
        height=20,
        tiny_space=True
    )
    # Category - combobox of predefined character sets (index 1 = 'NUMERIC')
    self.comb_category = ttk.Combobox(self.parent, values=(
        'CUSTOMIZED',
        'NUMERIC',
        'ALPHANUMERIC',
        'ALPHANUMERIC_LOWER',
        'ALPHANUMERIC_UPPER',
        'ALPHABET_LOWER',
        'ALPHABET_UPPER',
        'ALPHABET',
        'ARITHMETIC',
        'ARITHMETIC_MIX_ALPHA_LOWER',
        'FLOAT',
        'CHS_3500',
        'ALPHANUMERIC_CHS_3500_LOWER',
        'DOCUMENT_OCR',
    ), state='readonly')
    self.comb_category.current(1)
    self.comb_category.bind("<<ComboboxSelected>>", lambda x: self.comb_category_callback(x))
    self.layout_utils.next_to_widget(
        src=self.comb_category,
        target=self.category_text,
        width=225,
        height=20,
        tiny_space=True
    )
    # Category - custom entry, disabled initially. NOTE(review): presumably
    # comb_category_callback enables it for 'CUSTOMIZED' — not visible here.
    self.category_val = tk.StringVar()
    self.category_val.set('')
    self.category_entry = ttk.Entry(self.parent, textvariable=self.category_val, justify=tk.LEFT, state=tk.DISABLED)
    self.layout_utils.next_to_widget(
        src=self.category_entry,
        target=self.comb_category,
        width=440,
        height=20,
        tiny_space=False
    )
    # ============================= Group 3: Training Configuration ==============
    self.label_frame_train = ttk.Labelframe(self.parent, text='Training Configuration')
    self.layout_utils.below_widget(
        src=self.label_frame_train,
        target=self.label_frame_neu,
        width=790,
        height=60,
        tiny_space=True
    )
    # Stop criterion: accuracy - label
    self.end_acc_text = ttk.Label(self.parent, text='End Accuracy', anchor=tk.W)
    self.layout_utils.inside_widget(
        src=self.end_acc_text,
        target=self.label_frame_train,
        width=85,
        height=20,
    )
    # Stop criterion: accuracy - entry (default 0.95)
    self.end_acc_val = tk.DoubleVar()
    self.end_acc_val.set(0.95)
    self.end_acc_entry = ttk.Entry(self.parent, textvariable=self.end_acc_val, justify=tk.LEFT)
    self.layout_utils.next_to_widget(
        src=self.end_acc_entry,
        target=self.end_acc_text,
        width=56,
        height=20,
        tiny_space=True
    )
    # Stop criterion: average cost - label
    self.end_cost_text = ttk.Label(self.parent, text='End Cost', anchor=tk.W)
    self.layout_utils.next_to_widget(
        src=self.end_cost_text,
        target=self.end_acc_entry,
        width=60,
        height=20,
        tiny_space=False
    )
    # Stop criterion: average cost - entry (default 0.5)
    self.end_cost_val = tk.DoubleVar()
    self.end_cost_val.set(0.5)
    self.end_cost_entry = ttk.Entry(self.parent, textvariable=self.end_cost_val, justify=tk.LEFT)
    self.layout_utils.next_to_widget(
        src=self.end_cost_entry,
        target=self.end_cost_text,
        width=58,
        height=20,
        tiny_space=True
    )
    # Stop criterion: epochs - label
    self.end_epochs_text = ttk.Label(self.parent, text='End Epochs', anchor=tk.W)
    self.layout_utils.next_to_widget(
        src=self.end_epochs_text,
        target=self.end_cost_entry,
        width=72,
        height=20,
        tiny_space=False
    )
    # Stop criterion: epochs - spinbox (0..10000, default 2)
    self.end_epochs_spin = ttk.Spinbox(self.parent, from_=0, to=10000)
    self.end_epochs_spin.set(2)
    self.layout_utils.next_to_widget(
        src=self.end_epochs_spin,
        target=self.end_epochs_text,
        width=50,
        height=20,
        tiny_space=True
    )
    # Training batch size - label
    self.batch_size_text = ttk.Label(self.parent, text='Train BatchSize', anchor=tk.W)
    self.layout_utils.next_to_widget(
        src=self.batch_size_text,
        target=self.end_epochs_spin,
        width=90,
        height=20,
        tiny_space=False
    )
    # Training batch size - entry (default 64)
    self.batch_size_val = tk.IntVar()
    self.batch_size_val.set(64)
    self.batch_size_entry = ttk.Entry(self.parent, textvariable=self.batch_size_val, justify=tk.LEFT)
    self.layout_utils.next_to_widget(
        src=self.batch_size_entry,
        target=self.batch_size_text,
        width=40,
        height=20,
        tiny_space=True
    )
    # Validation batch size - label
    self.validation_batch_size_text = ttk.Label(self.parent, text='Validation BatchSize', anchor=tk.W)
    self.layout_utils.next_to_widget(
        src=self.validation_batch_size_text,
        target=self.batch_size_entry,
        width=120,
        height=20,
        tiny_space=False
    )
    # Validation batch size - entry (default 300)
    self.validation_batch_size_val = tk.IntVar()
    self.validation_batch_size_val.set(300)
    self.validation_batch_size_entry = ttk.Entry(self.parent, textvariable=self.validation_batch_size_val, justify=tk.LEFT)
    self.layout_utils.next_to_widget(
        src=self.validation_batch_size_entry,
        target=self.validation_batch_size_text,
        width=40,
        height=20,
        tiny_space=True
    )
    # ============================= Group 4: Project Configuration ===============
    self.label_frame_project = ttk.Labelframe(self.parent, text='Project Configuration')
    self.layout_utils.below_widget(
        src=self.label_frame_project,
        target=self.label_frame_train,
        width=790,
        height=60,
        tiny_space=True
    )
    # Project name - label
    self.project_name_text = ttk.Label(self.parent, text='Project Name', anchor=tk.W)
    self.layout_utils.inside_widget(
        src=self.project_name_text,
        target=self.label_frame_project,
        width=90,
        height=20
    )
    # Project name - editable combobox. <Return> commits a typed name, a click
    # refreshes the project list, and selecting an entry loads its config.
    self.comb_project_name = ttk.Combobox(self.parent)
    self.layout_utils.next_to_widget(
        src=self.comb_project_name,
        target=self.project_name_text,
        width=430,
        height=20,
        tiny_space=True
    )
    self.comb_project_name.bind(
        sequence="<Return>",
        func=lambda x: self.project_name_fill_callback(x)
    )
    self.comb_project_name.bind(
        sequence="<Button-1>",
        func=lambda x: self.fetch_projects()
    )

    # Load the selected project's configuration on a separate thread.
    # NOTE(review): read_conf updates Tk widgets and variables; tkinter is
    # generally expected to be driven from the main thread — confirm this is safe.
    def read_conf(event):
        threading.Thread(target=self.read_conf).start()

    self.comb_project_name.bind("<<ComboboxSelected>>", read_conf)
    # Save configuration - button
    self.btn_save_conf = ttk.Button(
        self.parent, text='Save Configuration', command=lambda: self.save_conf()
    )
    self.layout_utils.next_to_widget(
        src=self.btn_save_conf,
        target=self.comb_project_name,
        width=130,
        height=24,
        tiny_space=False,
        offset_y=-2
    )
    # Delete project - button
    self.btn_delete = ttk.Button(
        self.parent, text='Delete', command=lambda: self.delete_project()
    )
    self.layout_utils.next_to_widget(
        src=self.btn_delete,
        target=self.btn_save_conf,
        width=80,
        height=24,
        tiny_space=False,
    )
    # ============================= Group 5: Sample Dataset ======================
    self.label_frame_dataset = ttk.Labelframe(
        self.parent, text='Sample Dataset'
    )
    self.layout_utils.below_widget(
        src=self.label_frame_dataset,
        target=self.label_frame_project,
        width=790,
        height=170,
        tiny_space=True
    )
    # Attach dataset - button
    self.btn_attach_dataset = ttk.Button(
        self.parent,
        text='Attach Dataset',
        command=lambda: self.attach_dataset()
    )
    self.layout_utils.inside_widget(
        src=self.btn_attach_dataset,
        target=self.label_frame_dataset,
        width=120,
        height=24,
    )
    # Attached dataset - read-only display of the last attached directory
    self.attach_dataset_val = tk.StringVar()
    self.attach_dataset_val.set('')
    self.attach_dataset_entry = ttk.Entry(
        self.parent, textvariable=self.attach_dataset_val, justify=tk.LEFT, state=tk.DISABLED
    )
    self.layout_utils.next_to_widget(
        src=self.attach_dataset_entry,
        target=self.btn_attach_dataset,
        width=420,
        height=24,
        tiny_space=True
    )
    # Validation set size - label
    self.validation_num_text = ttk.Label(self.parent, text='Validation Set Num', anchor=tk.W)
    self.layout_utils.next_to_widget(
        src=self.validation_num_text,
        target=self.attach_dataset_entry,
        width=120,
        height=20,
        tiny_space=False,
        offset_y=2
    )
    # Validation set size - entry (default 300; attach_dataset skips the
    # validation TFRecords entry when this is 0)
    self.validation_num_val = tk.IntVar()
    self.validation_num_val.set(300)
    self.validation_num_entry = ttk.Entry(self.parent, textvariable=self.validation_num_val, justify=tk.LEFT)
    self.layout_utils.next_to_widget(
        src=self.validation_num_entry,
        target=self.validation_num_text,
        width=71,
        height=20,
        tiny_space=True
    )
    # Training dataset (TFRecords) path - label.
    # NOTE(review): this rebinds self.dataset_train_path_text, dropping the
    # reference to the 'Training Path' label created in the Sample Source group.
    self.dataset_train_path_text = ttk.Label(self.parent, text='Training Dataset', anchor=tk.W)
    self.layout_utils.below_widget(
        src=self.dataset_train_path_text,
        target=self.btn_attach_dataset,
        width=100,
        height=20,
        tiny_space=False
    )
    # Training dataset path - listbox
    self.dataset_train_listbox = tk.Listbox(self.parent, font=('微软雅黑', 9))
    self.layout_utils.next_to_widget(
        src=self.dataset_train_listbox,
        target=self.dataset_train_path_text,
        width=640,
        height=36,
        tiny_space=False
    )
    self.dataset_train_listbox.bind(
        sequence="<Delete>",
        func=lambda x: self.listbox_delete_item_callback(x, self.dataset_train_listbox)
    )
    self.listbox_scrollbar(self.dataset_train_listbox)
    # Validation dataset path - label, placed manually below the training listbox.
    # NOTE(review): likewise rebinds self.dataset_validation_path_text.
    label_edge = self.layout_utils.object_edge_info(self.dataset_train_path_text)
    widget_edge = self.layout_utils.object_edge_info(self.dataset_train_listbox)
    self.dataset_validation_path_text = ttk.Label(self.parent, text='Validation Dataset', anchor=tk.W)
    self.dataset_validation_path_text.place(
        x=label_edge['x'],
        y=widget_edge['edge_y'] + self.layout['global']['space']['y'] / 2,
        width=100,
        height=20
    )
    # Validation dataset path - listbox
    self.dataset_validation_listbox = tk.Listbox(self.parent, font=('微软雅黑', 9))
    self.layout_utils.next_to_widget(
        src=self.dataset_validation_listbox,
        target=self.dataset_validation_path_text,
        width=640,
        height=36,
        tiny_space=False
    )
    self.dataset_validation_listbox.bind(
        sequence="<Delete>",
        func=lambda x: self.listbox_delete_item_callback(x, self.dataset_validation_listbox)
    )
    self.listbox_scrollbar(self.dataset_validation_listbox)
    # Lookup table: dataset type -> run mode -> the listbox holding its paths.
    # Directory = raw sample folders, TFRecords = packed dataset files.
    self.sample_map = {
        DatasetType.Directory: {
            RunMode.Trains: self.source_train_path_listbox,
            RunMode.Validation: self.source_validation_path_listbox
        },
        DatasetType.TFRecords: {
            RunMode.Trains: self.dataset_train_listbox,
            RunMode.Validation: self.dataset_validation_listbox
        }
    }
    # Action buttons, laid out right-to-left from the bottom-right corner.
    # Start training - button
    self.btn_training = ttk.Button(self.parent, text='Start Training', command=lambda: self.start_training())
    self.layout_utils.widget_from_right(
        src=self.btn_training,
        target=self.label_frame_dataset,
        width=120,
        height=24,
        tiny_space=True
    )
    # Stop training - button, disabled at start-up
    self.btn_stop = ttk.Button(self.parent, text='Stop', command=lambda: self.stop_training())
    self.button_state(self.btn_stop, tk.DISABLED)
    self.layout_utils.before_widget(
        src=self.btn_stop,
        target=self.btn_training,
        width=60,
        height=24,
        tiny_space=True
    )
    # Compile model - button
    self.btn_compile = ttk.Button(self.parent, text='Compile', command=lambda: self.compile())
    self.layout_utils.before_widget(
        src=self.btn_compile,
        target=self.btn_stop,
        width=80,
        height=24,
        tiny_space=True
    )
    # Make (pack) dataset - button
    self.btn_make_dataset = ttk.Button(self.parent, text='Make Dataset', command=lambda: self.make_dataset())
    self.layout_utils.before_widget(
        src=self.btn_make_dataset,
        target=self.btn_compile,
        width=120,
        height=24,
        tiny_space=True
    )
    # Reset training history - button
    self.btn_reset_history = ttk.Button(
        self.parent, text='Reset History', command=lambda: self.reset_history()
    )
    self.layout_utils.before_widget(
        src=self.btn_reset_history,
        target=self.btn_make_dataset,
        width=120,
        height=24,
        tiny_space=True
    )
    # Test model on a directory - button
    self.btn_testing = ttk.Button(
        self.parent, text='Testing', command=lambda: self.testing_model()
    )
    self.layout_utils.before_widget(
        src=self.btn_testing,
        target=self.btn_reset_history,
        width=80,
        height=24,
        tiny_space=True
    )
    # Apply the centred geometry computed at the top.
    self.parent.geometry(size)
@staticmethod
def threading_exec(func, *args) -> threading.Thread:
    """Run ``func(*args)`` on a new daemon thread and return the started thread.

    The thread is a daemon, so a still-running background job does not stop
    the interpreter from exiting when the window is closed.

    Args:
        func: callable to run on the worker thread.
        *args: positional arguments forwarded to *func*.

    Returns:
        The already-started ``threading.Thread``.
    """
    # Pass daemon= to the constructor: Thread.setDaemon() is deprecated since 3.10.
    th = threading.Thread(target=func, args=args, daemon=True)
    th.start()
    return th
def popup_data_augmentation(self):
    """Open the data-augmentation dialog, pre-filled with the current settings.

    When no project is selected yet, an error box is shown instead and no
    dialog is opened.
    """
    if self.current_project:
        DataAugmentationDialog().read_conf(self.data_augmentation_entity)
    else:
        messagebox.showerror(
            "Error!", "Please set the project name first."
        )
def popup_pretreatment(self):
    """Open the pretreatment dialog, pre-filled with the current settings.

    When no project is selected yet, an error box is shown instead and no
    dialog is opened.
    """
    if self.current_project:
        PretreatmentDialog().read_conf(self.pretreatment_entity)
    else:
        messagebox.showerror(
            "Error!", "Please set the project name first."
        )
@staticmethod
def listbox_scrollbar(listbox: tk.Listbox):
    """Give *listbox* a vertical scrollbar packed on its right-hand side.

    The scrollbar drives ``listbox.yview`` and the listbox reports its
    position back through ``yscrollcommand``, so the two stay in sync.
    """
    bar = tk.Scrollbar(listbox)
    bar.configure(command=listbox.yview)
    bar.pack(fill=tk.Y, side=tk.RIGHT)
    listbox.configure(yscrollcommand=bar.set)
def blank_click(self, event):
    """Commit a typed project name when the user clicks in the window.

    If the project-name combobox text differs from the active project, the
    click *event* is forwarded to ``project_name_fill_callback``; otherwise
    nothing happens.
    """
    typed_name = self.comb_project_name.get()
    if typed_name == self.current_project:
        return
    self.project_name_fill_callback(event)
def project_name_fill_callback(self, event):
    """Turn a newly typed project name into a fresh project.

    Bound to <Return> on the project-name combobox and also called from
    blank_click. Nothing happens unless the combobox text is non-empty and
    not in ``self.project_names``. For a new name:

    * label extraction is reset (regex and 'FileName' label source),
    * both source-directory listboxes and the custom category are cleared,
    * the suffix ``-<CNN>-<Recurrent>-H<units>-<Loss>-C<channels>`` is
      appended unless the name already ends with it,
    * the resulting name becomes the current project,
    * both TFRecords listboxes are reset to the default dataset paths (each
      reset calls save_conf),
    * the augmentation and pretreatment settings are replaced by new defaults.

    Args:
        event: the Tk event that triggered the call; not used.
    """
    # Suffix encoding the currently selected network options.
    suffix = '-{}-{}-H{}-{}-C{}'.format(
        self.comb_neu_cnn.get(),
        self.comb_recurrent.get(),
        self.units_num_spin.get(),
        self.comb_loss.get(),
        self.comb_channel.get(),
    )
    current_project_name = self.comb_project_name.get()
    # NOTE(review): self.project_names is not assigned in the visible code;
    # presumably fetch_projects fills it. If it has never run, this raises
    # AttributeError — confirm.
    if len(current_project_name) > 0 and current_project_name not in self.project_names:
        # Default label extraction: the file-name text before the first underscore.
        self.extract_regex = ".*?(?=_)"
        self.label_from_var.set('FileName')
        self.sample_map[DatasetType.Directory][RunMode.Trains].delete(0, tk.END)
        self.sample_map[DatasetType.Directory][RunMode.Validation].delete(0, tk.END)
        self.category_val.set("")
        if not current_project_name.endswith(suffix):
            self.comb_project_name.insert(tk.END, suffix)
        self.current_project = self.comb_project_name.get()
        self.update_dataset_files_path(mode=RunMode.Trains)
        self.update_dataset_files_path(mode=RunMode.Validation)
        # NOTE(review): these resets run after update_dataset_files_path has
        # already called save_conf. If save_conf persists these entities, the new
        # project's first saved config carries the previous settings — confirm.
        self.data_augmentation_entity = DataAugmentationEntity()
        self.pretreatment_entity = PretreatmentEntity()
@property
def project_path(self):
    """Directory of the current project, created on demand.

    Returns:
        ``None`` when no project is selected; otherwise the string
        ``"<project_root_path>/<current_project>"``. The directory (and any
        missing parents) is created if it does not exist yet.
    """
    if not self.current_project:
        return None
    project_path = "{}/{}".format(self.project_root_path, self.current_project)
    # exist_ok=True avoids the race between an existence check and makedirs
    # (another thread could create the directory in between).
    os.makedirs(project_path, exist_ok=True)
    return project_path
def update_dataset_files_path(self, mode: RunMode):
    """Reset the TFRecords listbox for *mode* to its first dataset file, then save.

    Afterwards the listbox holds exactly one entry,
    ``<project_path>/dataset/<mode.value>.0.tfrecords`` with backslashes
    turned into forward slashes, and ``save_conf`` has been called.
    """
    relative_name = "dataset/{}.0.tfrecords".format(mode.value)
    full_path = os.path.join(self.project_path, relative_name).replace("\\", '/')
    target_listbox = self.sample_map[DatasetType.TFRecords][mode]
    target_listbox.delete(0, tk.END)
    target_listbox.insert(tk.END, full_path)
    self.save_conf()
def attach_dataset(self):
    """Attach an extra sample directory to the current project.

    Bails out with an error when a training task is running or no project is
    selected, and silently when the directory dialog is cancelled or
    ``check_dataset`` rejects the current config. Otherwise:

    * the directory is shown and added to the training source listbox,
    * the Attach button is disabled,
    * an extra TFRecords path per mode is appended to the dataset listboxes
      (validation only when the validation set size is non-zero),
    * the configuration is saved and re-read,
    * ``DataSets.make_dataset(..., is_add=True)`` runs on a background thread,
      whose callback re-enables the Attach button.
    """
    if self.is_task_running:
        messagebox.showerror(
            "Error!", "Please terminate the current training first or wait for the training to end."
        )
        return
    if not self.current_project:
        messagebox.showerror(
            "Error!", "Please set the project name first."
        )
        return
    filename = filedialog.askdirectory()
    if not filename:
        return
    model_conf = ModelConfig(self.current_project)
    # check_dataset is defined elsewhere; presumably it reports its own error.
    if not self.check_dataset(model_conf):
        return
    self.attach_dataset_val.set(filename)
    self.sample_map[DatasetType.Directory][RunMode.Trains].insert(tk.END, filename)
    # Re-enabled by the make_dataset callback below.
    self.button_state(self.btn_attach_dataset, tk.DISABLED)
    for mode in [RunMode.Trains, RunMode.Validation]:
        attached_dataset_name = model_conf.dataset_increasing_name(mode)
        attached_dataset_name = "dataset/{}".format(attached_dataset_name)
        attached_dataset_path = os.path.join(self.project_path, attached_dataset_name)
        attached_dataset_path = attached_dataset_path.replace("\\", '/')
        # No validation split requested: do not list a validation file.
        if mode == RunMode.Validation and self.validation_num_val.get() == 0:
            continue
        self.sample_map[DatasetType.TFRecords][mode].insert(tk.END, attached_dataset_path)
    # Persist the new listbox entries, then reload the config for the background
    # job (presumably so it sees them — confirm save_conf writes what ModelConfig reads).
    self.save_conf()
    model_conf = ModelConfig(self.current_project)
    # NOTE(review): make_dataset runs on the worker thread, so `callback` (widget
    # state) and `msg` (messagebox) are likely invoked off the Tk main thread —
    # confirm this is safe.
    self.threading_exec(
        lambda: DataSets(model_conf).make_dataset(
            trains_path=filename,
            is_add=True,
            callback=lambda: self.button_state(self.btn_attach_dataset, tk.NORMAL),
            msg=lambda x: tk.messagebox.showinfo('Attach Dataset Status', x)
        )
    )
    pass
@staticmethod
def button_state(btn: ttk.Button, state: str):
    """Set the ``state`` option of *btn* (e.g. ``tk.NORMAL`` or ``tk.DISABLED``)."""
    # Equivalent to btn['state'] = state: tkinter's __setitem__ delegates to configure.
    btn.configure(state=state)
def delete_project(self):
    """Delete the current project's directory from disk.

    Shows an error and returns when no project is selected or a training
    task is running. If the directory cannot be removed, the error is shown
    and nothing else changes. On success a confirmation is shown, the
    project-name combobox is cleared and no project is selected any more.
    """
    if not self.current_project:
        messagebox.showerror(
            "Error!", "Please select a project to delete."
        )
        return
    if self.is_task_running:
        messagebox.showerror(
            "Error!", "Please terminate the current training first or wait for the training to end."
        )
        return
    # Same location as the project_path property, but without its
    # create-if-missing side effect.
    project_path = "{}/{}".format(self.project_root_path, self.current_project)
    try:
        shutil.rmtree(project_path)
    except Exception as e:
        messagebox.showerror(
            "Error!", json.dumps(e.args, ensure_ascii=False)
        )
        # Do not go on to report success for a failed deletion.
        return
    messagebox.showinfo(
        "Success!", "Delete successful!"
    )
    self.comb_project_name.delete(0, tk.END)
    # The project is gone; forget it so later actions (and the project_path
    # property, which would recreate the directory) see no project selected.
    self.current_project = ""
def reset_history(self):
    """Delete the current project's ``model`` directory (its training history).

    Shows an error and returns when no project is selected or a training
    task is running. If removal fails, the error is shown and no success
    message follows.
    """
    if not self.current_project:
        messagebox.showerror(
            "Error!", "Please select a project first."
        )
        return
    if self.is_task_running:
        messagebox.showerror(
            "Error!", "Please terminate the current training first or wait for the training to end."
        )
        return
    project_history_path = "{}/{}/model".format(self.project_root_path, self.current_project)
    try:
        shutil.rmtree(project_history_path)
    except Exception as e:
        messagebox.showerror(
            "Error!", json.dumps(e.args, ensure_ascii=False)
        )
        # Do not go on to report success for a failed deletion.
        return
    messagebox.showinfo(
        "Success!", "Delete history successful!"
    )
def testing_model(self):
    """Ask for an image directory and run ``Predict.testing`` on it for the current project.

    Shows an error dialog instead when no project is selected, matching the
    guard used by the other project actions.
    """
    if not self.current_project:
        messagebox.showerror(
            "Error!", "Please select a project first."
        )
        return
    filename = filedialog.askdirectory()
    if not filename:
        return
    filename = filename.replace("\\", "/")
    predict = Predict(project_name=self.current_project)
    predict.testing(image_dir=filename, limit=self.validation_batch_size)
def clear_dataset(self):
    """Delete the project's ``dataset`` directory and empty both dataset listboxes.

    Refuses (with an error dialog) when no project is selected or a training
    task is running. On a failing ``shutil.rmtree`` the error is shown and
    nothing else changes. The saved config is not updated here.
    """
    if not self.current_project:
        messagebox.showerror(
            "Error!", "Please select a project first."
        )
        return
    if self.is_task_running:
        messagebox.showerror(
            "Error!", "Please terminate the current training first or wait for the training to end."
        )
        return
    project_dataset_path = "./projects/{}/dataset".format(self.current_project)
    try:
        shutil.rmtree(project_dataset_path)
    except OSError as e:
        messagebox.showerror(
            "Error!", json.dumps(e.args, ensure_ascii=False)
        )
        return
    # Clear from index 0: these listboxes hold only dataset paths (read_conf
    # fills them entirely), so starting at 1 left a deleted file listed.
    self.dataset_train_listbox.delete(0, tk.END)
    self.dataset_validation_listbox.delete(0, tk.END)
    messagebox.showinfo(
        "Success!", "Clear dataset successful!"
    )
@staticmethod
def popup_about():
    """Show the "About" dialog with the core version and the author's contact details."""
    about_text = "Eve-DL Trainer CORE_VERSION({})\n\nAuthor's mailbox: kerlomz@gmail.com\n\nQQ Group: 857149419".format(get_version())
    messagebox.showinfo("About", about_text)
def auto_loss(self, event):
    """Combobox callback: select the ``CrossEntropy`` loss when ``NoRecurrent`` is chosen.

    ``event`` (the Tk event) is ignored; any other recurrent choice leaves
    the loss selection untouched.
    """
    has_recurrent_layer = self.comb_recurrent.get() != 'NoRecurrent'
    if has_recurrent_layer:
        return
    self.comb_loss.set("CrossEntropy")
@staticmethod
def get_param(src: dict, key, default=None):
    """Look up *key* in *src*, falling back to *default* for missing or falsy values.

    Stored values such as ``None``, ``[]``, ``""`` or ``0`` are treated like a
    missing key.
    """
    value = src.get(key)
    if not value:
        return default
    return value
def read_conf(self):
    """Load the project chosen in the project combobox into all UI widgets.

    Sets ``self.current_project`` to the selection and reads its
    ``ModelConfig``. It then refills the source/dataset listboxes, the
    network and training fields, the data-augmentation and pretreatment
    entities, and the category controls. The config is stored as
    ``self.model_conf`` and returned.
    """
    print('Reading configuration...')
    project = self.comb_project_name.get()
    self.current_project = project
    model_conf = ModelConfig(project)

    self.edit_var.set(model_conf.memory_usage)
    self.size_val.set("[{}, {}]".format(model_conf.image_width, model_conf.image_height))
    self.resize_val.set(json.dumps(model_conf.resize))

    # Empty every path listbox before refilling it from the config.
    for path_listbox in (
        self.source_train_path_listbox,
        self.source_validation_path_listbox,
        self.dataset_validation_listbox,
        self.dataset_train_listbox,
    ):
        path_listbox.delete(0, tk.END)

    def fill(target, paths_by_type, dataset_type):
        # Append each configured path of the given dataset type (missing -> none).
        for path in self.get_param(paths_by_type, dataset_type, default=[]):
            target.insert(tk.END, path)

    fill(self.source_train_path_listbox, model_conf.trains_path, DatasetType.Directory)
    fill(self.source_validation_path_listbox, model_conf.validation_path, DatasetType.Directory)

    self.label_num_spin.set(model_conf.max_label_num)
    self.comb_channel.set(model_conf.image_channel)
    self.comb_neu_cnn.set(model_conf.neu_cnn_param)
    self.comb_recurrent.set(model_conf.neu_recurrent_param)
    self.units_num_spin.set(model_conf.units_num)
    self.comb_loss.set(model_conf.loss_func_param)
    self.extract_regex = model_conf.extract_regex
    self.label_split = model_conf.label_split
    self.label_from_var.set(model_conf.label_from.value)
    self.comb_optimizer.set(model_conf.neu_optimizer_param)
    self.learning_rate_spin.set(float(model_conf.trains_learning_rate))
    self.end_acc_val.set(model_conf.trains_end_acc)
    self.end_cost_val.set(model_conf.trains_end_cost)
    self.end_epochs_spin.set(model_conf.trains_end_epochs)
    self.batch_size_val.set(model_conf.batch_size)
    self.validation_batch_size_val.set(model_conf.validation_batch_size)
    self.validation_num_val.set(model_conf.validation_set_num)

    # Entity attribute X is read from model_conf.da_X / model_conf.pre_X.
    for option in (
        "binaryzation", "median_blur", "gaussian_blur", "equalize_hist",
        "laplace", "warp_perspective", "rotate", "sp_noise", "brightness",
        "hue", "saturation", "gamma", "channel_swap", "random_blank",
        "random_transition", "random_captcha",
    ):
        setattr(self.data_augmentation_entity, option, getattr(model_conf, "da_" + option))
    for option in (
        "binaryzation", "replace_transparent", "horizontal_stitching",
        "concat_frames", "blend_frames", "exec_map",
    ):
        setattr(self.pretreatment_entity, option, getattr(model_conf, "pre_" + option))

    fill(self.dataset_validation_listbox, model_conf.validation_path, DatasetType.TFRecords)
    fill(self.dataset_train_listbox, model_conf.trains_path, DatasetType.TFRecords)

    # Category: a list means a customised set; very large sets are not shown.
    if not isinstance(model_conf.category_param, list):
        self.category_val.set("")
        self.category_entry['state'] = tk.DISABLED
        self.comb_category.set(model_conf.category_param)
    else:
        self.category_entry['state'] = tk.DISABLED
        self.comb_category.set('CUSTOMIZED')
        if len(model_conf.category_param) > 1000:
            self.category_val.set(NOT_EDITABLE_MSG)
        else:
            self.category_val.set(model_conf.category_param_text)
            self.category_entry['state'] = tk.NORMAL

    self.model_conf = model_conf
    return self.model_conf
@property
def validation_batch_size(self):
    """Validation batch size, taken verbatim from the validation batch-size field.

    An earlier variant capped it by the validation set size; that logic is
    not active.
    """
    batch_size_var = self.validation_batch_size_val
    return batch_size_var.get()
@property
def device_usage(self):
    """Current value of the memory-usage input (``edit_var``)."""
    usage_var = self.edit_var
    return usage_var.get()
def save_conf(self):
    """Build a ``ModelConfig`` from the current UI state, persist it, and return it.

    Shows an error dialog and returns None when no project name is set.
    Keyword groups below are merged in the same order as the UI values are
    read, so the widget getters run in an unchanged sequence.
    """
    if not self.current_project:
        messagebox.showerror(
            "Error!", "Please set the project name first."
        )
        return
    core = dict(
        project_name=self.current_project,
        MemoryUsage=self.device_usage,
        CNNNetwork=self.neu_cnn,
        RecurrentNetwork=self.neu_recurrent,
        UnitsNum=self.units_num_spin.get(),
        Optimizer=self.optimizer,
        LossFunction=self.loss_func,
        Decoder=self.comb_loss.get(),
        ModelName=self.current_project,
        ModelField=ModelField.Image.value,
        ModelScene=ModelScene.Classification.value,
        Category=self.category,
        Resize=self.resize,
        ImageChannel=self.comb_channel.get(),
        ImageWidth=self.image_width,
        ImageHeight=self.image_height,
        MaxLabelNum=self.label_num_spin.get(),
        AutoPadding=True,
        ReplaceTransparent=False,
        HorizontalStitching=False,
        OutputSplit='',
        LabelFrom=self.label_from_var.get(),
        ExtractRegex=self.extract_regex,
        LabelSplit=self.label_split,
    )
    dataset_paths = {
        key: self.dataset_value(dataset_type=dataset_type, mode=mode)
        for key, dataset_type, mode in (
            ("DatasetTrainsPath", DatasetType.TFRecords, RunMode.Trains),
            ("DatasetValidationPath", DatasetType.TFRecords, RunMode.Validation),
            ("SourceTrainPath", DatasetType.Directory, RunMode.Trains),
            ("SourceValidationPath", DatasetType.Directory, RunMode.Validation),
        )
    }
    training = dict(
        ValidationSetNum=self.validation_num_val.get(),
        SavedSteps=100,
        ValidationSteps=500,
        EndAcc=self.end_acc_val.get(),
        EndCost=self.end_cost_val.get(),
        EndEpochs=self.end_epochs_spin.get(),
        BatchSize=self.batch_size_val.get(),
        ValidationBatchSize=self.validation_batch_size,
        LearningRate=self.learning_rate_spin.get(),
    )
    augmentation = {
        key: getattr(self.data_augmentation_entity, attr)
        for key, attr in (
            ("DA_Binaryzation", "binaryzation"),
            ("DA_MedianBlur", "median_blur"),
            ("DA_GaussianBlur", "gaussian_blur"),
            ("DA_EqualizeHist", "equalize_hist"),
            ("DA_Laplace", "laplace"),
            ("DA_WarpPerspective", "warp_perspective"),
            ("DA_Rotate", "rotate"),
            ("DA_PepperNoise", "sp_noise"),
            ("DA_Brightness", "brightness"),
            ("DA_Saturation", "saturation"),
            ("DA_Hue", "hue"),
            ("DA_Gamma", "gamma"),
            ("DA_ChannelSwap", "channel_swap"),
            ("DA_RandomBlank", "random_blank"),
            ("DA_RandomTransition", "random_transition"),
            ("DA_RandomCaptcha", "random_captcha"),
        )
    }
    pretreatment = {
        key: getattr(self.pretreatment_entity, attr)
        for key, attr in (
            ("Pre_Binaryzation", "binaryzation"),
            ("Pre_ReplaceTransparent", "replace_transparent"),
            ("Pre_HorizontalStitching", "horizontal_stitching"),
            ("Pre_ConcatFrames", "concat_frames"),
            ("Pre_BlendFrames", "blend_frames"),
            ("Pre_ExecuteMap", "exec_map"),
        )
    }
    model_conf = ModelConfig(**core, **dataset_paths, **training, **augmentation, **pretreatment)
    model_conf.update()
    return model_conf
def make_dataset(self):
    """Save the config and build the TFRecords datasets from the source folders.

    The build is handed to ``threading_exec``. The "make dataset" button is
    disabled here and re-enabled by the builder's callback. It is re-enabled
    immediately if no training source folder has been added.
    """
    if not self.current_project:
        messagebox.showerror(
            "Error!", "Please set the project name first."
        )
        return
    if self.is_task_running:
        messagebox.showerror(
            "Error!", "Please terminate the current training first or wait for the training to end."
        )
        return
    self.save_conf()
    self.button_state(self.btn_make_dataset, tk.DISABLED)
    model_conf = ModelConfig(self.current_project)
    trains_dirs = self.dataset_value(DatasetType.Directory, RunMode.Trains)
    validation_dirs = self.dataset_value(DatasetType.Directory, RunMode.Validation)
    if not trains_dirs:
        messagebox.showerror(
            "Error!", "{} Sample set has not been added.".format(RunMode.Trains.value)
        )
        self.button_state(self.btn_make_dataset, tk.NORMAL)
        return

    def enable_button():
        return self.button_state(self.btn_make_dataset, tk.NORMAL)

    def report(message):
        return tk.messagebox.showinfo('Make Dataset Status', message)

    self.threading_exec(
        lambda: DataSets(model_conf).make_dataset(
            trains_path=trains_dirs,
            validation_path=validation_dirs,
            is_add=False,
            callback=enable_button,
            msg=report,
        )
    )
@property
def size(self):
    """Image size parsed from the size field as a list of ints (``json_filter`` result)."""
    raw_size = self.size_val.get()
    return self.json_filter(raw_size, int)
@property
def image_height(self):
    """Second element (height) of the parsed image size."""
    dimensions = self.size
    return dimensions[1]
@property
def image_width(self):
    """First element (width) of the parsed image size."""
    dimensions = self.size
    return dimensions[0]
@property
def resize(self):
    """Resize target parsed from the resize field as a list of ints (``json_filter`` result)."""
    raw_resize = self.resize_val.get()
    return self.json_filter(raw_resize, int)
@property
def neu_cnn(self):
    """Name of the CNN backbone selected in its combobox."""
    selector = self.comb_neu_cnn
    return selector.get()
@property
def neu_recurrent(self):
    """Name of the recurrent layer selected in its combobox."""
    selector = self.comb_recurrent
    return selector.get()
@property
def loss_func(self):
    """Name of the loss function selected in its combobox."""
    selector = self.comb_loss
    return selector.get()
@property
def optimizer(self):
    """Name of the optimizer selected in its combobox."""
    selector = self.comb_optimizer
    return selector.get()
@staticmethod
def json_filter(content, item_type):
    """Parse *content* as JSON and convert every element with *item_type*.

    Args:
        content: JSON text typed by the user (e.g. ``"[64, 32]"``).
        item_type: Callable applied to each parsed element (``int``, ``str``...).

    Returns:
        The list of converted elements. Returns ``None`` after showing an
        error dialog when any of these holds:
        - *content* is empty.
        - *content* is not valid JSON.
        - The parsed value cannot be iterated (e.g. a bare number).
        - An element is rejected by *item_type*.
    """
    if not content:
        messagebox.showerror(
            "Error!", "To select a customized category, you must specify the category set manually."
        )
        return None
    try:
        parsed = json.loads(content)
    except ValueError:
        messagebox.showerror(
            "Error!", "Input must be of type JSON."
        )
        return None
    try:
        return [item_type(item) for item in parsed]
    except (TypeError, ValueError):
        # Previously uncaught: e.g. "5" (not iterable) or '["a"]' with int.
        messagebox.showerror(
            "Error!", "Input must be a JSON array of valid values."
        )
        return None
@property
def category(self):
    """Category setting to save: a built-in category name or a custom list.

    Returns ``None`` (after an error dialog) when nothing is selected. For
    ``CUSTOMIZED``, the entry text is parsed as a JSON list of strings;
    single quotes are first normalised to double quotes. The
    ``NOT_EDITABLE_MSG`` placeholder instead yields
    ``self.model_conf.category_param_text``.
    """
    selected = self.comb_category.get()
    if not selected:
        messagebox.showerror(
            "Error!", "Please select built-in category or custom category first"
        )
        return None
    if selected != 'CUSTOMIZED':
        return selected
    raw_text = self.category_entry.get()
    if raw_text == NOT_EDITABLE_MSG:
        return self.model_conf.category_param_text
    return self.json_filter(raw_text.replace("'", '"'), str)
def dataset_value(self, dataset_type: DatasetType, mode: RunMode):
    """Return all entries of the listbox registered for (*dataset_type*, *mode*) as a list."""
    target = self.sample_map[dataset_type][mode]
    last_index = target.size() - 1
    entries = target.get(0, last_index)
    return list(entries)
def compile_task(self):
    """Compile the current project's trained model via ``Trains.compile_graph(0)``.

    Requires a selected project whose model folder exists and holds at least
    three entries (presumably checkpoint files; TODO confirm). The outcome
    is reported in a dialog.
    """
    if not self.current_project:
        messagebox.showerror(
            "Error!", "Please set the project name first."
        )
        return
    model_conf = ModelConfig(project_name=self.current_project)
    if not os.path.exists(model_conf.model_root_path):
        messagebox.showerror(
            "Error", "Model storage folder does not exist."
        )
        return
    if len(os.listdir(model_conf.model_root_path)) < 3:
        messagebox.showerror(
            "Error", "There is no training model record, please train before compiling."
        )
        return
    try:
        # Reuse the existing task only while it is training: start_training
        # locks the project combobox, so that task belongs to this project.
        # Otherwise a task left from an earlier run may be for another project.
        if not self.current_task or not self.is_task_running:
            self.current_task = Trains(model_conf)
        self.current_task.compile_graph(0)
        status = 'Compile completed'
    except Exception as e:
        # default=str: non-JSON-serializable args must not crash the handler.
        messagebox.showerror(
            e.__class__.__name__, json.dumps(e.args, ensure_ascii=False, default=str)
        )
        status = 'Compile failure'
    tk.messagebox.showinfo('Compile Status', status)
def compile(self):
    """Hand ``compile_task`` to ``threading_exec`` and keep its return value as ``self.job``."""
    self.job = self.threading_exec(self.compile_task)
def training_task(self):
    """Train the current project's model and report the outcome in a dialog.

    Meant to be run through ``threading_exec`` (see ``start_training``).
    While training, the train button is disabled and the stop button enabled.
    The UI is always restored afterwards, including when building the task
    itself fails. ``start_training`` locks the project combobox before
    calling this, so that lock must be released on every path.
    """
    status = 'Training failure'
    try:
        model_conf = ModelConfig(project_name=self.current_project)
        self.current_task = Trains(model_conf)
        self.button_state(self.btn_training, tk.DISABLED)
        self.button_state(self.btn_stop, tk.NORMAL)
        self.is_task_running = True
        self.current_task.train_process()
        status = 'Training completed'
    except Exception as e:
        traceback.print_exc()
        # default=str: non-JSON-serializable args must not crash the handler.
        messagebox.showerror(
            e.__class__.__name__, json.dumps(e.args, ensure_ascii=False, default=str)
        )
    finally:
        self.button_state(self.btn_training, tk.NORMAL)
        self.button_state(self.btn_stop, tk.DISABLED)
        self.comb_project_name['state'] = tk.NORMAL
        self.is_task_running = False
    tk.messagebox.showinfo('Training Status', status)
@staticmethod
def check_dataset(model_conf):
    """Check that the configured TFRecords training and validation files exist.

    Args:
        model_conf: The saved ``ModelConfig``. Its ``trains_path`` /
            ``validation_path`` mappings are keyed by ``DatasetType``.

    Returns:
        True when both sets are defined and every path exists. Otherwise
        False, after an error dialog.
    """
    if model_conf is None:
        # save_conf returns None when it fails; it has already shown an error.
        return False
    # .get instead of [...]: a missing key means "not defined", like read_conf's get_param.
    trains_path = model_conf.trains_path.get(DatasetType.TFRecords)
    validation_path = model_conf.validation_path.get(DatasetType.TFRecords)
    if not trains_path or not validation_path:
        messagebox.showerror(
            "Error!", "Training set or validation set not defined."
        )
        return False
    for paths, set_name in ((trains_path, "Training"), (validation_path, "Validation")):
        if not all(os.path.exists(path) for path in paths):
            messagebox.showerror(
                "Error!", "{} set path does not exist, please make dataset first".format(set_name)
            )
            return False
    return True
def start_training(self):
    """Validate settings, save the config and launch ``training_task``.

    Steps run in order, each aborting on failure:
    1. The resize/label check.
    2. The project-name check.
    3. ``save_conf`` followed by the dataset check.
    On success the project combobox is locked and ``training_task`` is
    handed to ``threading_exec``; its return value is kept as ``self.job``.
    """
    if not self.check_resize():
        return
    if not self.current_project:
        messagebox.showerror(
            "Error!", "Please set the project name first."
        )
        return
    saved_conf = self.save_conf()
    dataset_ready = self.check_dataset(saved_conf)
    if not dataset_ready:
        return
    self.comb_project_name['state'] = tk.DISABLED
    self.job = self.threading_exec(self.training_task)
def stop_training(self):
    """Signal the current training task to stop by setting its ``stop_flag``.

    Does nothing when no task has been created yet. The original raised
    AttributeError in that case.
    """
    task = self.current_task
    if not task:
        return
    task.stop_flag = True
@property
def project_names(self):
    """Names of the sub-directories of ``self.project_root_path`` (one per project).

    Returns an empty list when the root directory does not exist. The
    ``os.scandir`` iterator is closed deterministically via ``with``.
    """
    try:
        with os.scandir(self.project_root_path) as entries:
            return [entry.name for entry in entries if entry.is_dir()]
    except FileNotFoundError:
        return []
def fetch_projects(self):
    """Refresh the project combobox's drop-down values from ``project_names``."""
    names = self.project_names
    self.comb_project_name['values'] = names
def browse_dataset(self, dataset_type: DatasetType, mode: RunMode):
    """Let the user pick a sample folder and append it to the matching listbox.

    If the folder's first entry is a directory, the folder is treated as a
    parent. Its sub-directories are then added one by one, stopping at the
    first non-directory entry. Otherwise the chosen folder itself is added.
    Whenever the target listbox is still empty, the path about to be added
    is first passed to ``fetch_sample`` to pre-fill size/label settings.
    """
    if not self.current_project:
        messagebox.showerror(
            "Error!", "Please define the project name first."
        )
        return
    filename = filedialog.askdirectory()
    if not filename:
        return
    listbox = self.sample_map[dataset_type][mode]
    is_sub = False
    # ``with`` closes the scandir iterator (it was previously left open).
    # NOTE(review): os.scandir order is arbitrary, so for a folder that mixes
    # files and sub-directories the outcome depends on which entry comes first.
    with os.scandir(filename) as entries:
        for i, item in enumerate(entries):
            if not item.is_dir():
                break
            path = item.path.replace("\\", "/")
            if listbox.size() == 0:
                self.fetch_sample([path])
            listbox.insert(tk.END, path)
            if i == 0:
                is_sub = True
    if not is_sub:
        filename = filename.replace("\\", "/")
        if listbox.size() == 0:
            self.fetch_sample([filename])
        listbox.insert(tk.END, filename)
@staticmethod
def closest_category(category):
    """Return the built-in category key whose character set is the smallest superset of *category*.

    Returns None when no entry of ``SIMPLE_CATEGORY_MODEL`` contains every
    character. Ties go to the first such key in ``SIMPLE_CATEGORY_MODEL``
    iteration order.
    """
    wanted = set(category)
    candidates = {name: set(category_extract(name)) for name in SIMPLE_CATEGORY_MODEL.keys()}
    surplus = {
        name: len(chars) - len(wanted)
        for name, chars in candidates.items()
        if wanted <= chars
    }
    if not surplus:
        return None
    # min() keeps the first key with the smallest surplus, matching a first-match scan.
    return min(surplus, key=surplus.get)
def fetch_category(self):
    """Collect the label category set from the training samples into the UI.

    Saves the current UI config first and stops if that fails (no project
    set). TXT label sources are rejected with an error dialog. Sets longer
    than 1000 are stored on ``self.model_conf.category_param_text`` and the
    locked entry shows ``NOT_EDITABLE_MSG``. Smaller sets go into the
    editable entry. The config is saved again at the end.
    """
    model_conf = self.save_conf()
    if model_conf is None:
        # save_conf already showed an error. Keep the previous self.model_conf,
        # which the category property may still read (category_param_text).
        return
    self.model_conf = model_conf
    if self.model_conf.label_from == LabelFrom.TXT or self.label_from_var.get() == LabelFrom.TXT.value:
        messagebox.showerror(
            "Error!", "The Label From is currently not supported."
        )
        return
    category_list = fetch_category_list(self.model_conf, is_json=True)
    if not category_list:
        return
    self.comb_category.current(0)
    if len(category_list) > 1000:
        self.category_entry['state'] = tk.DISABLED
        self.category_val.set(NOT_EDITABLE_MSG)
        self.model_conf.category_param_text = category_list
    else:
        self.category_entry['state'] = tk.NORMAL
        self.category_val.set(category_list)
    self.save_conf()
def fetch_sample(self, dataset_path):
    """Pre-fill image size, label length and category from a sample folder.

    Looks at (at most) the first 100 entries of ``dataset_path[0]``:
    - Regular files named ``<label>_<rest>`` contribute the characters of
      ``<label>``. The longest such label sets the label-number spinbox.
    - The first regular file is opened with PIL. Its size is written to
      both the size and the resize fields.
    - Unless the category entry holds ``NOT_EDITABLE_MSG``, the closest
      built-in category for the collected characters is selected.
    Does nothing when there is no regular file to inspect. The label count
    and category are left unchanged when no labelled file names are found.
    """
    sample_dir = dataset_path[0]
    file_names = [
        name for name in os.listdir(sample_dir)[0:100]
        # Sub-directories are neither images nor labelled samples.
        if os.path.isfile(os.path.join(sample_dir, name))
    ]
    if not file_names:
        return
    category = []
    max_label_len = 0
    for file_name in file_names:
        if "_" in file_name:
            label = list(file_name.split("_")[0])
            max_label_len = max(max_label_len, len(label))
            category.extend(label)
    # Context manager releases the file handle PIL keeps open for lazy loading.
    with PilImage.open(os.path.join(sample_dir, file_names[0])) as image:
        size = image.size
    self.size_val.set(json.dumps(size))
    self.resize_val.set(json.dumps(size))
    if max_label_len > 0:
        # Previously the last file's length was used, or -1 with no labels.
        self.label_num_spin.set(max_label_len)
    if not category:
        # An empty set is a subset of every category; don't pick one arbitrarily.
        return
    if not self.category_val.get() or self.category_val.get() != NOT_EDITABLE_MSG:
        category_pram = self.closest_category(category)
        if not category_pram:
            return
        self.comb_category.set(category_pram)
def listbox_delete_item_callback(self, event, listbox: tk.Listbox):
    """Event callback: remove the selected row of *listbox*, then save the config.

    ``event`` is ignored. With no selection, the resulting IndexError is
    printed and otherwise ignored.
    """
    try:
        selected_index = listbox.curselection()[0]
        listbox.delete(selected_index)
        self.save_conf()
    except IndexError as err:
        print(err)
def comb_category_callback(self, event):
    """Combobox callback: unlock the category entry only for ``CUSTOMIZED``.

    Any other choice clears the entry text and disables it. ``event`` is ignored.
    """
    if self.comb_category.get() != 'CUSTOMIZED':
        self.category_entry.delete(0, tk.END)
        self.category_entry['state'] = tk.DISABLED
        return
    self.category_entry['state'] = tk.NORMAL
def check_resize(self):
    """Check that the CNN output size is divisible by the label count (non-CTC losses).

    CTC losses always pass. Otherwise, let ``param = OUTPUT_SHAPE1_MAP[...]``
    for the chosen CNN. Each resize dimension is divided by ``param[0]`` and
    rounded up; their product times ``param[1]`` must be a multiple of the
    label number.

    Returns:
        True when the configuration is usable. False after an error dialog,
        or silently when the resize value failed to parse (``json_filter``
        has already shown that error).
    """
    if self.loss_func == 'CTC':
        return True
    resize = self.resize  # parse once instead of twice
    if resize is None:
        return False
    if len(resize) < 2:
        messagebox.showerror(
            "Error!", "Resize must be of the form [width, height]."
        )
        return False
    param = OUTPUT_SHAPE1_MAP[NETWORK_MAP[self.neu_cnn]]
    shape1w = math.ceil(resize[0] / param[0])
    shape1h = math.ceil(resize[1] / param[0])
    input_s1 = shape1w * shape1h * param[1]
    label_num = int(self.label_num_spin.get())
    if label_num <= 0:
        # Previously a ZeroDivisionError for label_num == 0.
        messagebox.showerror(
            "Error!", "label_num = {} must be a positive integer.".format(label_num)
        )
        return False
    if input_s1 % label_num != 0:
        messagebox.showerror(
            "Error!", "Shape[1] = {} must divide the label_num = {}.".format(input_s1, label_num)
        )
        return False
    return True
@staticmethod
def resource_path(relative_path):
    """Resolve *relative_path* against the PyInstaller bundle dir, or the CWD.

    A frozen PyInstaller app unpacks its data to a temp folder exposed as
    ``sys._MEIPASS``. When running from source that attribute is absent and
    the current working directory is used instead.
    """
    if hasattr(sys, '_MEIPASS'):
        base_path = sys._MEIPASS
    else:
        base_path = os.path.abspath(".")
    return os.path.join(base_path, relative_path)
if __name__ == '__main__':
    # Script entry point: create the Tk root window, build the trainer UI
    # (Wizard) on it, then run Tk's event loop until the window is closed.
    root = tk.Tk()
    app = Wizard(root)
    root.mainloop()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化