# 代码拉取完成,页面将自动刷新
from __future__ import print_function, division
import _thread
import os
import cv2
import numpy as np
# from tensorflow.python.keras import Sequential, Input, Model
# from tensorflow.python.keras.layers import Conv2D, LeakyReLU, BatchNormalization, Dropout, Flatten, Dense
# from tensorflow.python.keras.optimizer_v2.adam import Adam
from keras import Input
from keras.engine.saving import load_model
from keras.layers import Dropout, BatchNormalization, LeakyReLU, Dense, Flatten
from keras.layers.convolutional import Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from util.dataloader import dataloader
import matplotlib.pyplot as plt
class CNN:
    """Small convolutional image classifier built on Keras.

    Training data and label names come from ``util.dataloader.dataloader``;
    the trained network is saved to and restored from ``./model/model.h5``.
    """

    # Location of the persisted model.
    model_dir = "./model"
    model_path = "./model/model.h5"

    def __init__(self, using_model=True):
        """Set hyper-parameters, create the dataloader and build or load the model.

        :param using_model: if True, reuse the saved model at ``model_path``
            when it exists; otherwise (or when it is missing) build a new one.
        """
        # Input image size expected by the network.
        self.img_rows = 100
        self.img_cols = 100
        self.channels = 1  # image channels: 3 for colour images, 1 for grayscale
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.epochs = 2000
        self.batch_size = 4
        # Save every this many epochs (in-memory mode) or batches (streaming mode).
        self.save_interval = 10
        # Optimizer (positional args: learning rate, beta_1).
        self.optimizer = Adam(0.0002, 0.5)
        self.load_all = True  # load the whole dataset into memory at once for training
        self.dataloader = dataloader(self.img_shape, load_all=self.load_all)

        if using_model and os.path.exists(self.model_path):
            self.model = load_model(self.model_path)
        else:
            if using_model:
                print("No model present there!")
            self.model = self.build_model()
        self.model.summary()

        # One label per image: categorical cross-entropy for several classes,
        # binary cross-entropy for two.
        if self.dataloader.multi_category:
            loss = 'categorical_crossentropy'
        else:
            loss = 'binary_crossentropy'
        # Compiled even after load_model, so a restored model continues with a
        # fresh optimizer state.
        self.model.compile(loss=loss, optimizer=self.optimizer, metrics=['accuracy'])

    def build_model(self):
        """Build the convolutional network.

        Three stride-2 4x4 convolutions (each halves the spatial size, rounding
        up) with LeakyReLU, followed by a 1-channel convolution that is
        flattened into the dense output layer.

        :return: a Keras ``Model`` mapping an ``img_shape`` image to class scores.
        """
        cnum = 16  # filters in the first conv layer; doubled in the next two
        model = Sequential()
        model.add(Conv2D(cnum, kernel_size=4, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Conv2D(cnum * 2, kernel_size=4, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dropout(0.25))
        model.add(Conv2D(cnum * 4, kernel_size=4, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(1, kernel_size=4, strides=1, padding="same"))
        model.add(Dropout(0.25))
        model.add(Flatten())
        if self.dataloader.multi_category:
            # Exactly one class per image (categorical cross-entropy loss,
            # argmax decoding in predict), so the outputs must form a
            # probability distribution: softmax rather than independent sigmoids.
            model.add(Dense(self.dataloader.num_category, activation='softmax'))
        else:
            model.add(Dense(1, activation='sigmoid'))
        model.summary()
        img = Input(shape=self.img_shape)
        validity = model(img)
        return Model(img, validity)

    def train(self):
        """Train the model for ``self.epochs`` epochs, saving it periodically.

        If the dataloader holds the whole dataset in memory, it is fitted in
        chunks of ``save_interval`` epochs and the model is saved after each
        chunk. Otherwise training runs one batch per step. The next batch is
        loaded in a background thread while the current one trains. The model
        is saved every ``save_interval`` steps and once more at the end.

        :raises ValueError: if ``save_interval`` is smaller than 1.
        :return: None
        """
        if self.save_interval < 1:
            raise ValueError("save_interval must be a positive integer")
        self.dataloader.load_all_data()
        if self.dataloader.load_all:
            self._train_in_memory()
        else:
            self._train_streaming()

    def _train_in_memory(self):
        """Fit the in-memory dataset, saving after every ``save_interval`` epochs."""
        remaining = self.epochs
        while remaining > 0:
            # The last chunk may be shorter, so exactly self.epochs epochs run.
            chunk = min(self.save_interval, remaining)
            self.model.fit(self.dataloader.x_data, self.dataloader.y_data,
                           batch_size=self.batch_size, epochs=chunk)
            self.save_model()
            remaining -= chunk

    def _train_streaming(self):
        """Train batch by batch, prefetching the next batch in a background thread."""
        import threading

        self.dataloader.self_load_data(self.batch_size)
        x_data = self.dataloader.x_data
        y_data = self.dataloader.y_data
        for epoch in range(self.epochs):
            # Load the next batch while training on the current one.
            loader = threading.Thread(target=self.dataloader.self_load_data,
                                      args=(self.batch_size,))
            loader.start()
            loss = self.model.train_on_batch(x_data, y_data)
            # Wait for the loader before reading its output. Reading without the
            # join could pick up a stale batch, or x and y from different batches.
            loader.join()
            x_data = self.dataloader.x_data
            y_data = self.dataloader.y_data
            print("%d [D loss: %.2f acc: %.2f]" % (epoch, loss[0], loss[1]))
            if epoch % self.save_interval == 0:
                self.save_model()
        # The periodic save misses the batches after the last multiple of
        # save_interval; save once more so the final weights are kept.
        if self.epochs > 0 and (self.epochs - 1) % self.save_interval != 0:
            self.save_model()

    def save_model(self):
        """Save the model to ``model_path``, creating ``model_dir`` if needed.

        :return: None
        """
        print("save model...")
        os.makedirs(self.model_dir, exist_ok=True)
        self.model.save(self.model_path)

    def predict(self, img):
        """Predict the label of a single preprocessed image.

        :param img: one image already processed by ``dataloader.process_img``,
            without a batch axis (one is added here).
        :return: the matching entry of ``dataloader.category``.
        """
        batch = np.expand_dims(img, axis=0)
        scores = self.model.predict(batch)[0]
        if self.dataloader.multi_category:
            return self.dataloader.category[int(np.argmax(scores))]
        # Binary model: a single sigmoid score; >= 0.5 selects category[1].
        score = float(np.ravel(scores)[0])
        return self.dataloader.category[1 if score >= 0.5 else 0]

    @staticmethod
    def _plot_result(img, label):
        """Show ``img`` in a matplotlib window with ``label`` as the title."""
        plt.imshow(img)
        plt.title(label)
        plt.show()

    def predict_by_numpyImg(self, img, is_plot=True):
        """Predict the label of a raw numpy image.

        :param img: numpy image, which is passed through ``dataloader.process_img``.
        :param is_plot: whether to display the image and label with matplotlib.
        :return: the predicted label.
        """
        label = self.predict(self.dataloader.process_img(img))
        if is_plot:
            self._plot_result(img, label)
        return label

    def predict_by_path(self, path, is_plot=True):
        """Predict the label of the image file at ``path``.

        :param path: path of the image file.
        :param is_plot: whether to display the image and label with matplotlib.
        :raises IOError: if OpenCV cannot read the file.
        :return: the predicted label.
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None instead of raising for missing or unreadable files.
            raise IOError("Cannot read image file: %s" % path)
        label = self.predict(self.dataloader.process_img(img))
        if is_plot:
            # OpenCV loads colour images as BGR; matplotlib expects RGB.
            self._plot_result(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), label)
        return label
def main():
    """Train a brand-new CNN on the dataloader's dataset."""
    # Pass using_model=True instead to continue from the saved model rather
    # than training a new one from scratch.
    classifier = CNN(using_model=False)
    # Image file names containing Chinese characters can break loading, so let
    # the dataloader rename every file first.
    classifier.dataloader.rename_all_file()
    classifier.train()


if __name__ == '__main__':
    main()
# 此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
# 如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。