# Fetching the repository succeeded.
# This action will force synchronization from 肆十二/vegetables_tf2.3, which will overwrite any changes that you have made since you forked the repository, and they cannot be recovered!
# The synchronization will run in the background, and the page will refresh when it finishes. Please be patient.
# -*- coding: utf-8 -*-
# @Time : 2021/6/17 20:29
# @Author : dejahu
# @Email : 1148392984@qq.com
# @File : test_model.py
# @Software: PyCharm
# @Brief : 模型测试代码,测试会生成热力图,热力图会保存在results目录下
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
# Configure matplotlib fonts: use the sans-serif family and make SimHei the
# sans-serif font, so CJK text in tick labels can be rendered.
plt.rcParams.update({
    'font.family': ['sans-serif'],
    'font.sans-serif': ['SimHei'],
})
# Data loading: build the training and test datasets from their own folders.
def data_load(data_dir, test_data_dir, img_height, img_width, batch_size):
    """Load the training and test image datasets from two directories.

    Both datasets are created with
    ``tf.keras.preprocessing.image_dataset_from_directory`` using
    ``label_mode='categorical'``, ``seed=123`` and the requested image size
    and batch size.

    Args:
        data_dir: Directory holding the training images.
        test_data_dir: Directory holding the test images.
        img_height: Target image height passed in ``image_size``.
        img_width: Target image width passed in ``image_size``.
        batch_size: Number of images per batch.

    Returns:
        A tuple ``(train_ds, val_ds, class_names)``; ``class_names`` is read
        from the training dataset.
    """
    # The two datasets differ only in their source directory.
    loader_options = dict(
        label_mode='categorical',
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size,
    )
    load_from_directory = tf.keras.preprocessing.image_dataset_from_directory

    # Training set.
    train_ds = load_from_directory(data_dir, **loader_options)
    # Test set.
    val_ds = load_from_directory(test_data_dir, **loader_options)

    return train_ds, val_ds, train_ds.class_names
# Shared evaluation pipeline used by test_mobilenet() and test_cnn().
def _collect_labels(model, test_ds):
    """Return ``(real, predicted)`` class-index arrays for every test sample.

    Each batch is predicted with ``model.predict`` and both the label batch
    and the prediction batch are reduced with ``np.argmax(..., axis=1)``.
    """
    real_labels = []
    pred_labels = []
    for batch_images, batch_labels in test_ds:
        batch_preds = model.predict(batch_images)
        real_labels.extend(np.argmax(batch_labels.numpy(), axis=1))
        pred_labels.extend(np.argmax(batch_preds, axis=1))
    # dtype=int keeps the arrays usable as indices even when they are empty.
    return np.asarray(real_labels, dtype=int), np.asarray(pred_labels, dtype=int)


def _confusion_counts(real_labels, pred_labels, num_classes):
    """Count samples per (real class, predicted class) pair as a float matrix."""
    counts = np.zeros((num_classes, num_classes))
    # np.add.at accumulates correctly when the same (row, col) pair repeats.
    np.add.at(counts, (real_labels, pred_labels), 1)
    return counts


def _normalize_rows(counts):
    """Divide each row by its sum; rows that sum to zero stay all zeros.

    A class with no test samples has a zero row sum; a plain division would
    produce NaN values (and a RuntimeWarning) for that row.
    """
    row_sums = np.sum(counts, axis=1).reshape(-1, 1)
    return np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums != 0)


def _evaluate_and_plot(model_path, model_label, save_name,
                       data_dir="../data/vegetable_fruit/image_data",
                       test_data_dir="../data/vegetable_fruit/test_image_data",
                       img_height=224, img_width=224, batch_size=16):
    """Evaluate a saved model on the test set and save its heatmap.

    Args:
        model_path: Path of the saved Keras model to load.
        model_label: Name printed in front of ``' test accuracy :'``.
        save_name: Output path of the heatmap image.
        data_dir: Training data directory (only its class names are used here).
        test_data_dir: Test data directory.
        img_height: Image height passed to ``data_load``.
        img_width: Image width passed to ``data_load``.
        batch_size: Batch size passed to ``data_load``.
    """
    # TODO: change data_dir / test_data_dir to the paths of your own dataset.
    _, test_ds, class_names = data_load(data_dir, test_data_dir,
                                        img_height, img_width, batch_size)
    model = tf.keras.models.load_model(model_path)

    # Overall accuracy on the test set; the loss value is not used.
    _, accuracy = model.evaluate(test_ds)
    print(model_label + ' test accuracy :', accuracy)

    # Per-sample inference to build the real-vs-predicted matrix.
    real_labels, pred_labels = _collect_labels(model, test_ds)
    heat_maps = _confusion_counts(real_labels, pred_labels, len(class_names))
    print(heat_maps)
    print()
    heat_maps_float = _normalize_rows(heat_maps)
    print(heat_maps_float)

    show_heatmaps(title="heatmap", x_labels=class_names, y_labels=class_names,
                  harvest=heat_maps_float, save_name=save_name)


# Test the accuracy of the MobileNet model.
def test_mobilenet():
    """Evaluate the saved MobileNet model and save ``results/heatmap_mobilenet.png``."""
    # TODO: change the model path to the name of your own model.
    _evaluate_and_plot("models/mobilenet_fv.h5", "Mobilenet",
                       "results/heatmap_mobilenet.png")


# Test the accuracy of the CNN model.
def test_cnn():
    """Evaluate the saved CNN model and save ``results/heatmap_cnn.png``."""
    # TODO: change the model path to the name of your own model.
    _evaluate_and_plot("models/cnn_fv.h5", "CNN",
                       "results/heatmap_cnn.png")
def show_heatmaps(title, x_labels, y_labels, harvest, save_name):
    """Draw an annotated heatmap of ``harvest`` and save it to ``save_name``.

    Note on the label parameters (kept as-is for backward compatibility):
    ``x_labels`` names the rows of ``harvest`` and is drawn on the vertical
    axis, while ``y_labels`` names the columns and is drawn on the horizontal
    axis.

    Args:
        title: Title placed above the plot.
        x_labels: Labels for the rows of ``harvest`` (vertical axis ticks).
        y_labels: Labels for the columns of ``harvest`` (horizontal axis ticks).
        harvest: 2-D array indexed as ``harvest[row, col]``; each cell value
            is written into the plot rounded to two decimals.
        save_name: Output image path; its parent directory is created if it
            does not exist.
    """
    import os  # local import so this function does not rely on a file-level one

    # Create the canvas.
    fig, ax = plt.subplots()
    try:
        # Colormap reference: https://blog.csdn.net/ztf312/article/details/102474190
        im = ax.imshow(harvest, cmap="OrRd")

        # Show every tick and label it with the matching list entry.
        ax.set_xticks(np.arange(len(y_labels)))
        ax.set_yticks(np.arange(len(x_labels)))
        ax.set_xticklabels(y_labels)
        ax.set_yticklabels(x_labels)

        # The horizontal labels are long, so rotate them for readability.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
                 rotation_mode="anchor")

        # Write the numeric value into every cell.
        for i in range(len(x_labels)):
            for j in range(len(y_labels)):
                ax.text(j, i, round(harvest[i, j], 2),
                        ha="center", va="center", color="black")

        ax.set_xlabel("Predict label")
        ax.set_ylabel("Actual label")
        ax.set_title(title)

        # Add the colorbar before tight_layout so the layout accounts for it.
        fig.colorbar(im, ax=ax)
        fig.tight_layout()

        # Saving fails if the target directory is missing, so create it first.
        out_dir = os.path.dirname(save_name)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_name, dpi=100)
    finally:
        # Close the figure so repeated calls do not accumulate open figures.
        # For interactive viewing, call plt.show() before this point.
        plt.close(fig)
if __name__ == '__main__':
    # Run the MobileNet evaluation first, then the CNN evaluation.
    for run_test in (test_mobilenet, test_cnn):
        run_test()
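# This section may contain content that is not suitable for display, so the page does not show it. You can review and modify it using the relevant editing features.
# If you confirm that the content does not involve inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / false information / worthless content, or content that violates applicable national laws and regulations, you can click submit to appeal and we will handle it as soon as possible.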