加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
evaluate.py 6.99 KB
一键复制 编辑 原始数据 按行查看 历史
YunYang1994 提交于 2019-05-14 11:53 . I hate tensorflow
#! /usr/bin/env python
# coding=utf-8
#================================================================
# Copyright (C) 2019 * Ltd. All rights reserved.
#
# Editor : VIM
# File name : evaluate.py
# Author : YunYang1994
# Created date: 2019-02-21 15:30:26
# Description :
#
#================================================================
import cv2
import os
import shutil
import numpy as np
import tensorflow as tf
import core.utils as utils
from core.config import cfg
from core.yolov3 import YOLOV3
class YoloTest(object):
    """Run a restored YOLOv3 graph over images and dump the detections to text files.

    Every setting (input size, thresholds, class names, anchors, annotation
    file, checkpoint path, output options) is read from ``cfg`` when the
    object is constructed.
    """

    def __init__(self):
        """Read the test settings, build the inference graph and restore the weights."""
        self.input_size = cfg.TEST.INPUT_SIZE
        self.anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE
        self.classes = utils.read_class_names(cfg.YOLO.CLASSES)
        self.num_classes = len(self.classes)
        self.anchors = np.array(utils.get_anchors(cfg.YOLO.ANCHORS))
        self.score_threshold = cfg.TEST.SCORE_THRESHOLD
        self.iou_threshold = cfg.TEST.IOU_THRESHOLD
        self.moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY
        self.annotation_path = cfg.TEST.ANNOT_PATH
        self.weight_file = cfg.TEST.WEIGHT_FILE
        self.write_image = cfg.TEST.WRITE_IMAGE
        self.write_image_path = cfg.TEST.WRITE_IMAGE_PATH
        self.show_label = cfg.TEST.SHOW_LABEL

        with tf.name_scope('input'):
            self.input_data = tf.placeholder(dtype=tf.float32, name='input_data')
            self.trainable = tf.placeholder(dtype=tf.bool, name='trainable')

        model = YOLOV3(self.input_data, self.trainable)
        self.pred_sbbox, self.pred_mbbox, self.pred_lbbox = (
            model.pred_sbbox, model.pred_mbbox, model.pred_lbbox)

        with tf.name_scope('ema'):
            ema_obj = tf.train.ExponentialMovingAverage(self.moving_ave_decay)

        self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        # The saver is driven by the name map the moving-average object
        # produces, not by the raw variable list.
        self.saver = tf.train.Saver(ema_obj.variables_to_restore())
        self.saver.restore(self.sess, self.weight_file)

    @staticmethod
    def _reset_dir(path):
        """Delete ``path`` if it exists and recreate it empty, including missing parents."""
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)

    @staticmethod
    def _format_bbox(label, bbox):
        """Return the result line ``'<label> <score> <xmin> <ymin> <xmax> <ymax>\\n'``.

        ``bbox[:4]`` are the box coordinates (cast to int32) and ``bbox[4]``
        is the score, written with four decimals.
        """
        coor = np.array(bbox[:4], dtype=np.int32)
        score = '%.4f' % bbox[4]
        xmin, ymin, xmax, ymax = list(map(str, coor))
        return ' '.join([label, score, xmin, ymin, xmax, ymax]) + '\n'

    def predict(self, image):
        """Run the network on one image and return the boxes kept after NMS.

        Args:
            image: image array whose ``shape`` unpacks into (height, width, channels).

        Returns:
            The output of ``utils.nms``; callers in this class read each entry
            as coordinates in ``[:4]``, score in ``[4]`` and class index in ``[5]``.

        Raises:
            ValueError: if ``image`` is None (e.g. a failed ``cv2.imread``).
        """
        if image is None:
            raise ValueError('image is None; it could not be loaded')

        org_image = np.copy(image)
        org_h, org_w, _ = org_image.shape

        image_data = utils.image_preporcess(image, [self.input_size, self.input_size])
        image_data = image_data[np.newaxis, ...]

        pred_sbbox, pred_mbbox, pred_lbbox = self.sess.run(
            [self.pred_sbbox, self.pred_mbbox, self.pred_lbbox],
            feed_dict={
                self.input_data: image_data,
                self.trainable: False
            }
        )

        # Flatten the three outputs to rows of (5 + num_classes) values and stack them.
        row_len = 5 + self.num_classes
        pred_bbox = np.concatenate([np.reshape(pred_sbbox, (-1, row_len)),
                                    np.reshape(pred_mbbox, (-1, row_len)),
                                    np.reshape(pred_lbbox, (-1, row_len))], axis=0)
        bboxes = utils.postprocess_boxes(pred_bbox, (org_h, org_w), self.input_size, self.score_threshold)
        bboxes = utils.nms(bboxes, self.iou_threshold)
        return bboxes

    def evaluate(self):
        """Write ground-truth and predicted boxes for every annotated image.

        Each non-blank line of ``self.annotation_path`` is an image path
        followed by whitespace-separated boxes, each box being comma-separated
        integers ``xmin,ymin,xmax,ymax,class_id``. For line number ``num`` the
        ground truth goes to ``./mAP/ground-truth/<num>.txt`` and the
        detections to ``./mAP/predicted/<num>.txt``. When ``self.write_image``
        is set, the image with drawn boxes is saved under
        ``self.write_image_path``. All three directories are emptied first.

        Raises:
            ValueError: if an image listed in the annotation file cannot be read.
        """
        predicted_dir_path = './mAP/predicted'
        ground_truth_dir_path = './mAP/ground-truth'
        for dir_path in (predicted_dir_path, ground_truth_dir_path, self.write_image_path):
            self._reset_dir(dir_path)

        with open(self.annotation_path, 'r') as annotation_file:
            for num, line in enumerate(annotation_file):
                annotation = line.strip().split()
                if not annotation:
                    # Blank line (e.g. trailing newline at end of file): nothing to evaluate.
                    continue

                image_path = annotation[0]
                image_name = os.path.basename(image_path)
                image = cv2.imread(image_path)
                if image is None:
                    raise ValueError('cannot read image: %s' % image_path)

                boxes_gt = [list(map(int, box.split(','))) for box in annotation[1:]]
                ground_truth_path = os.path.join(ground_truth_dir_path, str(num) + '.txt')

                print('=> ground truth of %s:' % image_name)
                with open(ground_truth_path, 'w') as f:
                    for box in boxes_gt:
                        class_name = self.classes[box[4]]
                        xmin, ymin, xmax, ymax = list(map(str, box[:4]))
                        bbox_mess = ' '.join([class_name, xmin, ymin, xmax, ymax]) + '\n'
                        f.write(bbox_mess)
                        print('\t' + str(bbox_mess).strip())

                print('=> predict result of %s:' % image_name)
                predict_result_path = os.path.join(predicted_dir_path, str(num) + '.txt')
                bboxes_pr = self.predict(image)

                if self.write_image:
                    image = utils.draw_bbox(image, bboxes_pr, show_label=self.show_label)
                    # join() instead of '+' so a directory without a trailing
                    # separator still yields a path inside that directory.
                    cv2.imwrite(os.path.join(self.write_image_path, image_name), image)

                with open(predict_result_path, 'w') as f:
                    for bbox in bboxes_pr:
                        class_name = self.classes[int(bbox[5])]
                        bbox_mess = self._format_bbox(class_name, bbox)
                        f.write(bbox_mess)
                        print('\t' + str(bbox_mess).strip())

    def voc_2012_test(self, voc2012_test_path):
        """Write per-class detection files for the images listed in a VOC2012 test split.

        Image ids are read from ``<voc2012_test_path>/ImageSets/Main/test.txt``
        (blank lines ignored) and images from ``JPEGImages/<id>.jpg``. Each
        detection is appended as ``<id> <score> <xmin> <ymin> <xmax> <ymax>``
        to ``results/VOC2012/Main/comp4_det_test_<class>.txt``; the results
        directory is emptied first.

        Args:
            voc2012_test_path: root directory of the VOC2012 test data.

        Raises:
            ValueError: if a listed image cannot be read.
        """
        img_inds_file = os.path.join(voc2012_test_path, 'ImageSets', 'Main', 'test.txt')
        with open(img_inds_file, 'r') as f:
            image_inds = [line.strip() for line in f if line.strip()]

        results_path = 'results/VOC2012/Main'
        self._reset_dir(results_path)

        for image_ind in image_inds:
            image_path = os.path.join(voc2012_test_path, 'JPEGImages', image_ind + '.jpg')
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError('cannot read image: %s' % image_path)

            print('predict result of %s:' % image_ind)
            bboxes_pr = self.predict(image)
            for bbox in bboxes_pr:
                class_name = self.classes[int(bbox[5])]
                bbox_mess = self._format_bbox(image_ind, bbox)
                result_file = os.path.join(results_path, 'comp4_det_test_' + class_name + '.txt')
                with open(result_file, 'a') as f:
                    f.write(bbox_mess)
                print('\t' + str(bbox_mess).strip())
if __name__ == '__main__':
    # Executed only when the file is run as a script: construct the tester
    # (which restores the model) and call its evaluate() method.
    tester = YoloTest()
    tester.evaluate()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化