加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test.py 1.13 KB
一键复制 编辑 原始数据 按行查看 历史
李志家 提交于 2020-06-19 14:48 . Initial commit
#coding=utf-8
import tensorflow as tf
import numpy as np
import pdb
from datetime import datetime
from VGG16 import *
import cv2
import os
def test(path):
x = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='input')
keep_prob = tf.placeholder(tf.float32)
output = VGG16(x, keep_prob, 17)
score = tf.nn.softmax(output)
f_cls = tf.argmax(score, 1)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess, './model/model.ckpt-9999')
for i in os.listdir(path):
imgpath = os.path.join(path, i)
im = cv2.imread(imgpath)
im = cv2.resize(im, (224 , 224))# * (1. / 255)
im = np.expand_dims(im, axis=0)
#pred = sess.run(f_cls, feed_dict={x:im, keep_prob:1.0})
pred, _score = sess.run([f_cls, score], feed_dict={x:im, keep_prob:1.0})
prob = round(np.max(_score), 4)
#print "{} flowers class is: {}".format(i, pred)
print ("{} flowers class is: {}, score: {}".format(i, int(pred), prob))
sess.close()
if __name__ == '__main__':
path = './test'
test(path)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化