当前仓库属于关闭状态,部分功能使用受限,详情请查阅 仓库状态说明
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
test.py 5.11 KB
一键复制 编辑 原始数据 按行查看 历史
万孝国 提交于 2020-04-17 13:41 . 修改版本6
### 进行错误诊断测试
def test(img_path):
crop_size = 128
resize_size = 128
img = Image.open(img_path)
# print(img_path)
# 统一图片大小
img = img.resize((resize_size, resize_size), Image.ANTIALIAS)
# 随机水平翻转
# r1 = random.random()
# if r1 > 0.5:
# img = img.transpose(Image.FLIP_LEFT_RIGHT)
# # 随机垂直翻转
# r2 = random.random()
# if r2 > 0.5:
# img = img.transpose(Image.FLIP_TOP_BOTTOM)
# # 随机角度翻转
# r3 = random.randint(-3, 3)
# img = img.rotate(r3, expand=False)
# # # 随机裁剪
# # r4 = random.randint(0, int(resize_size - crop_size))
# # r5 = random.randint(0, int(resize_size - crop_size))
# # box = (r4, r5, r4 + crop_size, r5 + crop_size)
# # img = img.crop(box)
# # 把图片转换成numpy值
# img = np.array(img).astype(np.float32)
# # 转换成CHW
# img = img.transpose((2, 0, 1))
# # 转换成BGR
# img = img[(2, 1, 0), :, :] / 255.0
return img
def load_image(file):
#打开图片
im = Image.open(file)
#将图片调整为跟训练数据一样的大小 32*32, 设定ANTIALIAS,即抗锯齿.resize是缩放
im = im.resize((32, 32), Image.ANTIALIAS)
#建立图片矩阵 类型为float32
im = np.array(im.convert("RGB")).astype(np.float32)
#矩阵转置
im = im.transpose((2, 0, 1))
#将像素值从【0-255】转换为【0-1】
im = im / 255.0
#print(im)
im = np.expand_dims(im, axis=0)
# 保持和之前输入image维度一致
print('im_shape的维度:',im.shape)
return im
bb = r'cat_12_train/tO6cKGH8uPEayzmeZJ51Fdr2Tx3fBYSn.jpg'
aa = r'cat_12_train/hwQDH3VBabeFXISfjlWEmYicoyr6qK1p.jpg'
load_image(aa)
img = Image.open(bb)
plt.imshow(img)
plt.show()
load_image(bb)
# if not os.path.exists(bb):
# print('不存在')
# else:
# img = Image.open(bb)
# plt.imshow(img)
# plt.show()
# img = test(bb)
# plt.imshow(img)
# plt.show()
# load_image(bb)
def test2(sample):
img = sample
try:
img = paddle.dataset.image.load_image(file=img, is_color=True)
plt.imshow(img)
plt.show()
img = paddle.dataset.image.simple_transform(im=img, resize_size=resize_size, crop_size=crop_size, is_color=True, is_train=True)
img = img.flatten().astype('float32') / 255.0
return img
except Exception as err:
print(sample)
print(err)
bb = r'cat_12_train/tO6cKGH8uPEayzmeZJ51Fdr2Tx3fBYSn.jpg'
test2(bb);
### 预测测试
# test_infer_path = []
# test_infer_path.append(r'cat_12_train/hwQDH3VBabeFXISfjlWEmYicoyr6qK1p.jpg')
# test_infer_path.append(r'cat_12_train/u9RFI6LdD7xi8YrKTtfNBE2qX0ZUsheg.jpg')
# test_infer_path.append(r'cat_12_train/bEkVJ6cdRWyXjl842umngODsBfhTIUte.jpg')
# test_infer_path.append(r'cat_12_train/BjWieGnrq0zpFaAmNyTvLfc47hOKRgSM.jpg')
# test_infer_path.append(r'cat_12_train/Q8mkFIpAR9qMzJEaijPlOvcNGWBf425o.jpg')
# test_infer_path.append(r'cat_12_train/N5p3cFKAMiseqtZDUayRPkX60YuHdoT9.jpg')
# test_infer_path.append(r'cat_12_train/2xvtgnXlHws7r3LQZ9PoNi6qJ51OW0ed.jpg')
# test_infer_path.append(r'cat_12_train/3ITPOakWJShC2UtZeLNoGn8fbdRxDHjw.jpg')
# test_infer_path.append(r'cat_12_train/qNy8KsJC9GiO6rLmbgpneAE2tjcHVhfZ.jpg')
# test_infer_path.append(r'cat_12_train/0iYHbBF5vXlG7u3DC6p9mc2j1xItRWZS.jpg')
# test_infer_path.append(r'cat_12_train/kBQV7YS3rgW2oJEesimD5Pnx0hpuZRH9.jpg')
# test_infer_path.append(r'cat_12_train/yvQ8O27RFKHrcZu93TgYEJA4hGqbenSX.jpg')
test_infer_path = []
label_infer = []
with open('','r') as f:
for line in f.readlines():
img_path,label = line.split('\t')
test_infer_path.append(img_path)
label_infer.append(label)
with fluid.scope_guard(inference_scope):
#从指定目录中加载 推理model(inference model)
[inference_program, # 预测用的program
feed_target_names, # 是一个str列表,它包含需要在推理 Program 中提供数据的变量的名称。
fetch_targets] = fluid.io.load_inference_model(model_save_dir,#fetch_targets:是一个 Variable 列表,从中我们可以得到推断结果。
infer_exe) #infer_exe: 运行 inference model的 executor
count = 0
for infer_path in test_infer_path:
img = load_image(infer_path)
results = infer_exe.run(inference_program, #运行预测程序
feed={feed_target_names[0]: img}, #喂入要预测的img
fetch_list=fetch_targets) #得到推测结果
data_imgs = os.listdir('cat_12_test')
# print(data_imgs)
img_label = {}
with open('test.csv','r') as f:
lines = f.readlines()
for line in lines:
img , label = line.split(',')
img_label[img] = label
print(len(img_label))
key = img_label.keys()
for i in data_imgs:
if i not in key:
print(i)
with open('testout.csv','w') as f:
for i in data_imgs:
f.write(i+','+img_label[i])
print('修改完成')
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化