加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
CNN.py 6.19 KB
一键复制 编辑 原始数据 按行查看 历史
Quix 提交于 2020-06-14 18:50 . first commit
from __future__ import print_function, division

import _thread
import os
import threading

import cv2
import matplotlib.pyplot as plt
import numpy as np
# from tensorflow.python.keras import Sequential, Input, Model
# from tensorflow.python.keras.layers import Conv2D, LeakyReLU, BatchNormalization, Dropout, Flatten, Dense
# from tensorflow.python.keras.optimizer_v2.adam import Adam
from keras import Input
from keras.engine.saving import load_model
from keras.layers import Dropout, BatchNormalization, LeakyReLU, Dense, Flatten
from keras.layers.convolutional import Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

from util.dataloader import dataloader
class CNN:
    """Small convolutional image classifier built on Keras.

    Training images come from the project's ``dataloader``. The trained model
    is persisted to ``./model/model.h5``.

    When ``dataloader.multi_category`` is true, the network ends in a softmax
    over ``dataloader.num_category`` classes. Otherwise it ends in a single
    sigmoid unit for binary classification.
    """

    # Location of the persisted model; shared by __init__ and save_model.
    MODEL_DIR = "./model"
    MODEL_PATH = "./model/model.h5"

    def __init__(self, using_model=True):
        """Build (or load) and compile the model.

        :param using_model: if True and ``./model/model.h5`` exists, load that
            model instead of building a fresh one
        """
        # Dataset image size.
        self.img_rows = 100
        self.img_cols = 100
        self.channels = 1  # number of image channels: 3 for colour, 1 for grayscale
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.epochs = 2000
        self.batch_size = 4
        self.save_interval = 10
        # Optimizer: Adam with learning rate 0.0002 and beta_1 = 0.5.
        self.optimizer = Adam(0.0002, 0.5)
        self.load_all = True  # load the whole dataset into memory for training
        self.dataloader = dataloader(self.img_shape, load_all=self.load_all)

        if using_model and os.path.exists(self.MODEL_PATH):
            self.model = load_model(self.MODEL_PATH)
        else:
            if using_model:
                print("No model present there!")
            self.model = self.build_model()
        self.model.summary()

        # A loaded model is recompiled as well, so it always trains with self.optimizer.
        if self.dataloader.multi_category:
            loss = 'categorical_crossentropy'
        else:
            loss = 'binary_crossentropy'
        self.model.compile(loss=loss, optimizer=self.optimizer, metrics=['accuracy'])

    def build_model(self):
        """Build the convolutional network.

        Three stride-2 convolutions downsample the input. A 1-filter
        convolution then collapses the channels, and a dense head produces the
        class scores.

        :return: a Keras ``Model`` mapping an ``img_shape`` image to class scores
        """
        cnum = 16
        model = Sequential()
        model.add(Conv2D(cnum, kernel_size=4, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Conv2D(cnum * 2, kernel_size=4, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dropout(0.25))
        model.add(Conv2D(cnum * 4, kernel_size=4, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(1, kernel_size=4, strides=1, padding="same"))
        model.add(Dropout(0.25))
        model.add(Flatten())
        if self.dataloader.multi_category:
            # Categorical cross-entropy and the argmax in predict() treat the
            # classes as mutually exclusive. That needs a probability
            # distribution over classes, so use softmax, not sigmoid.
            model.add(Dense(self.dataloader.num_category, activation='softmax'))
        else:
            model.add(Dense(1, activation='sigmoid'))
        model.summary()
        img = Input(shape=self.img_shape)
        output = model(img)
        return Model(img, output)

    def train(self):
        """Train the model for ``self.epochs`` epochs, saving periodically.

        With ``dataloader.load_all`` the in-memory dataset is fitted in chunks
        of ``save_interval`` epochs, saving after each chunk. Otherwise batches
        of ``batch_size`` are trained one at a time, and the next batch is
        prefetched on a background thread. In both modes, if at least one
        epoch runs, the model is saved after the last epoch.

        :return: None
        """
        self.dataloader.load_all_data()
        if self.dataloader.load_all:
            self._train_in_memory()
        else:
            self._train_streaming()

    def _train_in_memory(self):
        """Fit on the fully loaded dataset, saving after every ``save_interval`` epochs."""
        for start in range(0, self.epochs, self.save_interval):
            end = min(start + self.save_interval, self.epochs)
            # initial_epoch/epochs make Keras run epochs [start, end) with
            # cumulative numbering; a trailing partial chunk is not dropped.
            self.model.fit(self.dataloader.x_data, self.dataloader.y_data,
                           batch_size=self.batch_size,
                           epochs=end, initial_epoch=start)
            self.save_model()

    def _train_streaming(self):
        """Train batch by batch, loading the next batch while the current one trains."""
        self.dataloader.self_load_data(self.batch_size)
        x_data, y_data = self.dataloader.x_data, self.dataloader.y_data
        for epoch in range(self.epochs):
            # NOTE(review): assumes self_load_data rebinds x_data/y_data to new
            # objects rather than filling the current batch in place - confirm.
            loader = threading.Thread(target=self.dataloader.self_load_data,
                                      args=(self.batch_size,))
            loader.start()
            loss = self.model.train_on_batch(x_data, y_data)
            # Wait for the prefetch to finish. Reading x_data/y_data while the
            # loader may still be assigning them could pair images with another
            # batch's labels, or silently reuse the old batch.
            loader.join()
            x_data, y_data = self.dataloader.x_data, self.dataloader.y_data
            print("%d [D loss: %.2f acc: %.2f]" % (epoch, loss[0], loss[1]))
            if epoch % self.save_interval == 0:
                self.save_model()
        # Persist the final weights if the last epoch was not already a save point.
        if self.epochs > 0 and (self.epochs - 1) % self.save_interval != 0:
            self.save_model()

    def save_model(self):
        """Save the model to ``./model/model.h5``, creating ``./model`` if needed.

        :return: None
        """
        print("save model...")
        # exist_ok replaces the racy "check exists, then create" sequence.
        os.makedirs(self.MODEL_DIR, exist_ok=True)
        self.model.save(self.MODEL_PATH)

    def predict(self, img):
        """Predict the label of one image already processed by the dataloader.

        :param img: image returned by ``dataloader.process_img``
        :return: the matching entry of ``dataloader.category``
        """
        # The model expects a batch: add a leading axis, then take the only row.
        scores = self.model.predict(np.expand_dims(img, axis=0))[0]
        if self.dataloader.multi_category:
            return self.dataloader.category[int(np.argmax(scores))]
        # Binary head: a single sigmoid output thresholded at 0.5.
        positive = float(np.ravel(scores)[0]) >= 0.5
        return self.dataloader.category[1 if positive else 0]

    def predict_by_numpyImg(self, img, is_plot=True):
        """Predict the label of a raw numpy image.

        :param img: numpy image; preprocessed here with ``dataloader.process_img``
        :param is_plot: whether to display the image and its label with matplotlib
        :return: the predicted label
        """
        label = self.predict(self.dataloader.process_img(img))
        if is_plot:
            self._show_result(img, label)
        return label

    def predict_by_path(self, path, is_plot=True):
        """Predict the label of the image stored at ``path``.

        :param path: path of the image file
        :param is_plot: whether to display the image and its label with matplotlib
        :return: the predicted label
        :raises OSError: if OpenCV cannot read an image from ``path``
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure (e.g. missing file, unsupported format)
            # by returning None instead of raising.
            raise OSError("Cannot read image: %s" % path)
        label = self.predict(self.dataloader.process_img(img))
        if is_plot:
            # OpenCV loads images as BGR, but matplotlib displays them as RGB.
            self._show_result(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), label)
        return label

    @staticmethod
    def _show_result(img, label):
        """Show ``img`` in a matplotlib window titled with ``label``."""
        plt.imshow(img)
        plt.title(label)
        plt.show()
def main():
    """Train a fresh model on the images provided by the dataloader."""
    # Pass using_model=True instead to continue from a previously saved model.
    classifier = CNN(using_model=False)
    # Image names containing Chinese characters can break loading, so have the
    # dataloader rename every file before training.
    classifier.dataloader.rename_all_file()
    classifier.train()


if __name__ == '__main__':
    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化