加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 10.12 KB
一键复制 编辑 原始数据 按行查看 历史
from PIL import Image
import tensorflow.keras.backend as K
from tensorflow.python.keras.layers import *
from tensorflow.python.keras.models import *
from tensorflow.python.keras.optimizers import *
from tensorflow.python.keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint
import tensorflow as tf
import glob,pickle
import random
import time
import numpy as np
import tensorflow.gfile as gfile
import matplotlib.pyplot as mp
NUMBER = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
CAPTCHA_CHARSET = NUMBER # digit-only charset used to generate captchas
CAPTCHA_LEN = 4 # number of characters per captcha
CAPTCHA_HEIGHT = 60 # captcha image height in pixels
CAPTCHA_WIDTH = 160 # captcha image width in pixels (original comment said "length" — typo)
TRAIN_DATA_DIR = './images/' # training image directory
TEST_DATA_DIR = './validation/' # validation image directory
BATCH_SIZE = 100 # samples per training batch
EPOCHS = 40 # number of training epochs
LEARN_RATE = 0.1 # NOTE(review): unused — the active Nadam optimizer below hard-codes lr=0.0002
# OPT = Adam(lr=LEARN_RATE,amsgrad=True) # alternative: Adam optimizer
# OPT = 'RMSprop'
# OPT = 'Nadam'
OPT = Nadam(lr=0.0002)
# OPT = SGD(lr=LEARN_RATE, decay=LEARN_RATE/EPOCHS, momentum=0.9, nesterov=True)
# LOSS = 'binary_crossentropy' # binary cross-entropy treats each output component as independent
LOSS = 'categorical_crossentropy'
# Model file storage path and extension
MODEL_DIR = './model/train_demo/'
MODEL_FORMAT = '.h5'
# Training-history file storage path and extension
HISTORY_DIR = './history/train_demo/'
HISTORY_FORMAT = '.history'
# Filename template: {dir} captcha_{loss}_bs_{batch}_epochs_{epochs}{ext}
filename_str = "{} captcha_{}_bs_{}_epochs_{}{}"
# # Model architecture visualization file
# MODEL_VIS_FILE = 'captcha_classification.png'
# Model file path
MODEL_FILE = filename_str.format(MODEL_DIR , LOSS, str(BATCH_SIZE),
str(EPOCHS),MODEL_FORMAT)
# Training-history file path
HISTORY_FILE = filename_str.format(HISTORY_DIR , LOSS, str(BATCH_SIZE),
str(EPOCHS),HISTORY_FORMAT)
# Grayscale conversion
def rgb2gray(image):
    """Collapse the RGB channels of *image* into a single luminance channel.

    Uses the BT.601 luma weights (0.299, 0.587, 0.114). Only the first three
    channels are read, so a trailing alpha channel, if present, is ignored.
    The returned array drops the channel axis.
    """
    weights = np.array([0.299, 0.587, 0.114])
    return image[..., :3] @ weights
# One-hot encoding
def text2vec(text, length=CAPTCHA_LEN, charset=CAPTCHA_CHARSET):
    """One-hot encode *text* against *charset* into a flat vector.

    The result has length ``length * len(charset)``; character at position
    ``p`` sets index ``p * len(charset) + charset.index(char)`` to 1.

    Raises:
        ValueError: if ``len(text) != length``.
    """
    n = len(text)
    if n != length:
        raise ValueError(
            "输入字符长度为{},与所需验证码长度{}不相符".format(n, length))
    size = len(charset)
    vec = np.zeros(length * size)
    for pos, ch in enumerate(text):
        vec[pos * size + charset.index(ch)] = 1
    return vec
# Decode a one-hot vector back into its text
def vec2text(vector):
    """Decode a flat one-hot (or softmax) vector into a captcha string.

    The vector is reshaped into CAPTCHA_LEN rows, and each row's argmax
    selects one character from CAPTCHA_CHARSET.
    """
    rows = np.reshape(np.asarray(vector), [CAPTCHA_LEN, -1])
    return ''.join(CAPTCHA_CHARSET[int(np.argmax(row))] for row in rows)
# Adapt image batches to the Keras channel layout
def fit_keras_channels(batch, rows=CAPTCHA_HEIGHT, cols=CAPTCHA_WIDTH):
    """Reshape a batch of 2-D grayscale images to match Keras' channel order.

    Args:
        batch: array of shape (N, rows, cols) (or reshapeable to it).
        rows: image height, defaults to CAPTCHA_HEIGHT.
        cols: image width, defaults to CAPTCHA_WIDTH.

    Returns:
        (batch, input_shape): the reshaped batch and the per-sample
        input shape to pass to the model's Input layer.
    """
    # Bug fix: K.image_data_format() returns 'channels_first' or
    # 'channels_last'; the original compared against 'channel first',
    # so the channels-first branch could never be taken.
    if K.image_data_format() == 'channels_first':
        batch = batch.reshape(batch.shape[0], 1, rows, cols)
        input_shape = (1, rows, cols)
    else:
        batch = batch.reshape(batch.shape[0], rows, cols, 1)
        input_shape = (rows, cols, 1)
    return batch, input_shape
if __name__ == '__main__':
    # ---- Load the training set ----
    X_train, Y_train = [],[]
    # Collect every '.jpg' file in the training directory and shuffle the order
    filename = []
    filename = glob.glob(TRAIN_DATA_DIR + '*.jpg')
    random.seed(time.time())
    random.shuffle(filename)
    for file in filename:
        X_train.append(np.array(Image.open(file)))
        # The label is the filename stem. NOTE(review): lstrip/rstrip strip
        # *character sets*, not prefixes/suffixes; this only works because the
        # digit labels share no characters with the stripped sets — verify if
        # the charset or directory names ever change.
        Y_train.append(file.lstrip(TRAIN_DATA_DIR+'\\').rstrip('.jpg'))
    # ---- Preprocess training images ----
    # Stack into a float32 numpy array
    X_train = np.array(X_train, dtype=np.float32)
    # RGB -> grayscale
    X_train = rgb2gray(X_train)
    # Scale pixel values into [0, 1]
    X_train = X_train / 255
    # Reshape to the channel layout Keras expects
    X_train, input_shape = fit_keras_channels(X_train)
    print(X_train.shape, type(X_train))
    print(input_shape)
    # ---- One-hot encode training labels ----
    Y_train = list(Y_train)
    for i in range(len(Y_train)):
        Y_train[i] = text2vec(Y_train[i])
    Y_train = np.asarray(Y_train)
    print(Y_train.shape, type(Y_train))
    # ---- Load and preprocess the validation set (same pipeline) ----
    X_test,Y_test = [],[]
    filename = []
    filename = glob.glob(TEST_DATA_DIR + '*.jpg')
    random.seed(time.time())
    random.shuffle(filename)
    for file in filename:
        X_test.append(np.array(Image.open(file)))
        Y_test.append(file.lstrip(TEST_DATA_DIR+'\\').rstrip('.jpg'))
    # Validation images
    X_test = np.array(X_test, dtype=np.float32)
    X_test = rgb2gray(X_test) / 255
    X_test,_ = fit_keras_channels(X_test)
    # Validation labels
    Y_test = list(Y_test)
    for i in range(len(Y_test)):
        Y_test[i] = text2vec(Y_test[i])
    Y_test = np.asarray(Y_test)
    print(X_test.shape)
    print(Y_test.shape)
    # ---- Build a VGG16-style network ----
    # Input layer
    with tf.name_scope('inputs'):
        inputs = Input(shape=input_shape, name='inputs')
    # Conv stage 1: two Conv(64) + BN + ReLU, then 2x2 max-pool
    with tf.name_scope('con1'):
        conv1 = Conv2D(64, (3,3), name='conv1',padding='same', kernel_initializer='he_uniform')(inputs)
        bn1 = BatchNormalization()(conv1)
        act1 = Activation('relu')(bn1)
        # drop1 = Dropout(0.3)(act1)
        conv2 = Conv2D(64, (3, 3), name='conv2',padding='same', kernel_initializer='he_uniform')(act1)
        bn2 = BatchNormalization()(conv2)
        act2 = Activation('relu')(bn2)
        pool1 = MaxPooling2D(pool_size=(2, 2), padding='same', name='pool1')(act2)
    # Conv stage 2: two Conv(128) + BN + ReLU, then 2x2 max-pool
    with tf.name_scope('con2'):
        conv3 = Conv2D(128, (3, 3), name='conv3',padding='same', kernel_initializer='he_uniform')(pool1)
        bn3 = BatchNormalization()(conv3)
        act3 = Activation('relu')(bn3)
        # drop2 = Dropout(0.4)(act3)
        conv4 = Conv2D(128, (3, 3), name='conv4',padding='same', kernel_initializer='he_uniform')(act3)
        bn4 = BatchNormalization()(conv4)
        act4 = Activation('relu')(bn4)
        pool2 = MaxPooling2D(pool_size=(2, 2), padding='same', name='pool2')(act4)
    # Conv stage 3: three Conv(256) + BN + ReLU, then 2x2 max-pool
    with tf.name_scope('con3'):
        conv5 = Conv2D(256, (3,3), name='conv5',padding='same', kernel_initializer='he_uniform')(pool2)
        bn5 = BatchNormalization()(conv5)
        act5 = Activation('relu')(bn5)
        # drop3 = Dropout(0.4)(act5)
        conv6 = Conv2D(256, (3, 3), name='conv6',padding='same', kernel_initializer='he_uniform')(act5)
        bn6 = BatchNormalization()(conv6)
        act6 = Activation('relu')(bn6)
        # drop4 = Dropout(0.4)(act6)
        conv7 = Conv2D(256, (3, 3), name='conv7',padding='same', kernel_initializer='he_uniform')(act6)
        bn7 = BatchNormalization()(conv7)
        act7 = Activation('relu')(bn7)
        pool3 = MaxPooling2D(pool_size=(2, 2), padding='same', name='pool3')(act7)
    # Conv stage 4: three Conv(512) + BN + ReLU, then 2x2 max-pool
    # NOTE(review): stages 4 and 5 are never connected to the model output
    # (the Flatten below reads pool3), so these layers are dead branches.
    with tf.name_scope('con4'):
        conv8 = Conv2D(512, (3, 3), name='conv8',padding='same', kernel_initializer='he_uniform')(pool3)
        bn8 = BatchNormalization()(conv8)
        act8 = Activation('relu')(bn8)
        # drop5 = Dropout(0.4)(act8)
        conv9 = Conv2D(512, (3, 3), name='conv9',padding='same', kernel_initializer='he_uniform')(act8)
        bn9 = BatchNormalization()(conv9)
        act9 = Activation('relu')(bn9)
        # drop6 = Dropout(0.4)(act9)
        conv10 = Conv2D(512, (3, 3), name='conv10',padding='same', kernel_initializer='he_uniform')(act9)
        bn10 = BatchNormalization()(conv10)
        act10 = Activation('relu')(bn10)
        pool4 = MaxPooling2D(pool_size=(2, 2), padding='same', name='pool4')(act10)
    # Conv stage 5: three Conv(512) + BN + ReLU, then 2x2 max-pool (also dead, see note above)
    with tf.name_scope('con5'):
        conv11 = Conv2D(512, (3, 3), name='conv11',padding='same', kernel_initializer='he_uniform')(pool4)
        bn11 = BatchNormalization()(conv11)
        act11 = Activation('relu')(bn11)
        # drop7 = Dropout(0.4)(act11)
        conv12 = Conv2D(512, (3, 3), name='conv12',padding='same', kernel_initializer='he_uniform')(act11)
        bn12 = BatchNormalization()(conv12)
        act12 = Activation('relu')(bn12)
        # drop8 = Dropout(0.4)(act12)
        conv13 = Conv2D(512, (3, 3), name='conv13',padding='same', kernel_initializer='he_uniform')(act12)
        bn13 = BatchNormalization()(conv13)
        act13 = Activation('relu')(bn13)
        pool5 = MaxPooling2D(pool_size=(2, 2), padding='same', name='pool5')(act13)
    # Fully connected head
    with tf.name_scope('dense'):
        # Flatten the pooled feature maps (wired to pool3 — see NOTE above)
        x = Flatten()(pool3)
        # Dropout for regularization
        x = Dropout(0.5)(x)
        x1 = Dense(4096)(x)
        bnx1 = BatchNormalization()(x1)
        actx1 = Activation('relu')(bnx1)
        drop9 = Dropout(0.4)(actx1)
        x2 = Dense(4096)(drop9)
        bnx2 = BatchNormalization()(x2)
        x = Activation('relu')(bnx2)
        # Four 10-way softmax heads, one per captcha character
        x = [Dense(10, activation='softmax', name='func%d'%(i+1))(x) for i in range(4)]
    # Output layer
    with tf.name_scope('outputs'):
        # Concatenate the four per-character distributions into one 40-vector
        outs = Concatenate()(x)
    # Define the model's input and output.
    # NOTE(review): categorical_crossentropy over the concatenated 40-dim
    # vector treats all four softmax outputs as one distribution — confirm
    # this is intended vs. a 4-output model or binary_crossentropy.
    model = Model(inputs=inputs, outputs=outs)
    model.compile(optimizer=OPT, loss=LOSS, metrics=['accuracy'])
    model.summary()
    # Optionally resume training from a previous checkpoint
    # model.load_weights('./model/train_demo/ captcha_categorical_crossentropy_bs_100_epochs_200.h5')
    # callbacks = [ModelCheckpoint('./model/cnn_best_vgg.h5', save_best_only=True)]
    # Train; per-epoch metrics are kept in `history`
    history = model.fit(X_train,Y_train,
                        batch_size=BATCH_SIZE,
                        epochs=EPOCHS,verbose=2,
                        validation_data=(X_test,Y_test)
                        )
    # Sanity-check: decode one validation label and its prediction
    print(vec2text(Y_test[22]))
    yy = model.predict(X_test[22].reshape(1, 60, 160, 1))
    print(vec2text(yy))
    if not gfile.Exists(MODEL_DIR):
        gfile.MakeDirs(MODEL_DIR)
    # Save the trained model
    model.save(MODEL_FILE)
    print(MODEL_FILE)
    # Save the training history
    if not gfile.Exists(HISTORY_DIR):
        gfile.MakeDirs(HISTORY_DIR)
    with open(HISTORY_FILE, 'wb') as f:
        pickle.dump(history.history, f)
    print(HISTORY_FILE)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化