代码拉取完成,页面将自动刷新
import matplotlib.pyplot as mp
import pickle
from PIL import Image
import tensorflow.keras as keras
from tensorflow.python.keras.layers import *
from tensorflow.python.keras.models import *
from tensorflow.python.keras.optimizers import *
from tensorflow.python.keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint
import glob, pickle
import random
import time
import tensorflow as tf
import tensorflow.keras.backend as K
NUMBER = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
CAPTCHA_CHARSET = NUMBER # 使用数字字符集生成验证码
CAPTCHA_LEN = 4 # 验证码长度
CAPTCHA_HEIGHT = 60 # 验证码高度
CAPTCHA_WIDTH = 160 # 验证码长度
TRAIN_DATA_DIR = './images/' # 验证码训练数据集路径
import numpy as np
import tensorflow.gfile as gfile
TEST_DATA_DIR = './test/'
# 适配Keras图像数据格式通道
def fit_keras_channels(batch, rows=CAPTCHA_HEIGHT, cols=CAPTCHA_WIDTH):
if K.image_data_format() == 'channel first':
batch = batch.reshape(batch.shape[0], 1, rows, cols)
input_shape = (1, rows, cols)
else:
batch = batch.reshape(batch.shape[0], rows, cols, 1)
input_shape = (rows, cols, 1)
return batch, input_shape
# 灰度化
def rgb2gray(image):
return np.dot(image[..., :3], [0.299, 0.587, 0.114])
# one-hot编码
def text2vec(text, length=CAPTCHA_LEN, charset=CAPTCHA_CHARSET):
text_len = len(text)
# 验证码长度校验
if text_len != length:
raise ValueError(
"输入字符长度为{},与所需验证码长度{}不相符".format(text_len, length))
vec = np.zeros(length * len(charset))
for i in range(length):
vec[charset.index(text[i]) + i * len(charset)] = 1
return vec
# 向量转为字符
def vec2text(vector):
if not isinstance(vector, np.ndarray):
vector = np.asarray(vector)
vector = np.reshape(vector, [CAPTCHA_LEN, -1])
text = ''
for item in vector:
text += CAPTCHA_CHARSET[np.argmax(item)]
return text
class MCDropout(Dropout):
def call(self, inputs):
return super().call(inputs, training=True)
if __name__ == '__main__':
history_file = './history/train_demo/ captcha_adam_binary_crossentropy_bs_100_epochs_10.history'
model_file = './model/train_demo/ captcha_adam_binary_crossentropy_bs_100_epochs_10.h5'
# 读取测试集数据,处理图像和标签
X_test, Y_test = [], []
# 读取测试集数据
for filename in glob.glob(TEST_DATA_DIR + '*.jpg'):
X_test.append(np.array(Image.open(filename)))
Y_test.append(filename.lstrip(TEST_DATA_DIR + '\\').rstrip('.jpg'))
# 处理图像
X_test = np.array(X_test, dtype=np.float32)
X_test = rgb2gray(X_test) / 255
X_test, _ = fit_keras_channels(X_test)
# 处理标签
Y_test = list(Y_test)
for i in range(len(Y_test)):
Y_test[i] = text2vec(Y_test[i])
Y_test = np.asarray(Y_test)
# 加载模型
# model = keras.models.load_model('./model/train_demo/ captcha_categorical_crossentropy_bs_100_epochs_40.h5')
model = keras.models.load_model('./model/cnn_best_vgg_0.99.h5')
# 预测单张样例
print(vec2text(Y_test[22]))
yy = model.predict(X_test[22].reshape(1, 60, 160, 1))
print(vec2text(yy))
# 预测测试集效果
count = 0
for i in range(len(Y_test)):
pred = vec2text(model.predict(X_test[i].reshape(1, 60, 160, 1)))
real = vec2text(Y_test[i])
if (pred == real):
count = count + 1
print("样本数:{},正确数:{},错误数:{},准确率:{}".format(len(Y_test), count, len(Y_test) - count, count / len(Y_test)))
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。