加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
eval.py 2.08 KB
一键复制 编辑 原始数据 按行查看 历史
欧红旭 提交于 2020-06-08 14:55 . first commit
import tensorflow as tf
import numpy as np
# 模型目录
CHECKPOINT_DIR = './runs/1581236664/checkpoints'
INCEPTION_MODEL_FILE = 'model/tensorflow_inception_graph.pb'
# inception-v3模型参数
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0' # inception-v3模型中代表瓶颈层结果的张量名称
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0' # 图像输入张量对应的名称
file_path = "data/photos/ambulance/(1)_a_car.jpg"
# 评估
checkpoint_file = tf.train.latest_checkpoint(CHECKPOINT_DIR)
with tf.Graph().as_default() as graph:
with tf.compat.v1.Session().as_default() as sess:
# 读取训练好的inception-v3模型
with tf.io.gfile.GFile(INCEPTION_MODEL_FILE, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
# 加载inception-v3模型,并返回数据输入张量和瓶颈层输出张量
bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(
graph_def,
return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])
# 加载元图和变量
saver = tf.compat.v1.train.import_meta_graph('{}.meta'.format(checkpoint_file))
saver.restore(sess, checkpoint_file)
# 通过名字从图中获取输入占位符
input_x = graph.get_operation_by_name(
'BottleneckInputPlaceholder').outputs[0]
# 我们想要评估的tensors
predictions = graph.get_operation_by_name('evaluation/ArgMax').outputs[0]
# 读取数据
image_data = tf.io.gfile.GFile(file_path, 'rb').read()
# 使用inception-v3处理图片获取特征向量
bottleneck_values = sess.run(bottleneck_tensor,
{jpeg_data_tensor: image_data})
# 将四维数组压缩成一维数组,由于全连接层输入时有batch的维度,所以用列表作为输入
bottleneck_values = [np.squeeze(bottleneck_values)]
# 收集预测值
all_predictions = []
all_predictions = sess.run(predictions, {input_x: bottleneck_values})
print(all_predictions)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化