代码拉取完成,页面将自动刷新
from __future__ import print_function
import sys
# The script needs at least: datacfg, cfgfile and one weight file.
if len(sys.argv) < 4:
    print('Usage:', file=sys.stderr)
    print('python eval.py datacfg cfgfile weight1 weight2 ...', file=sys.stderr)
    # sys.exit does not depend on the `site` module (which is what provides
    # the bare exit() builtin), and a non-zero status tells the calling shell
    # that the invocation was wrong.
    sys.exit(1)
import time
import torch
from torchvision import datasets, transforms
from torch.autograd import Variable
import dataset
import random
import math
from utils import *
from cfg import parse_cfg
from darknet import Darknet
# ---- Command-line arguments and config files -----------------------------
datacfg, cfgfile = sys.argv[1], sys.argv[2]
data_options = read_data_cfg(datacfg)
# First cfg block (presumably the [net] section) -- it carries 'batch'.
net_options = parse_cfg(cfgfile)[0]

trainlist = data_options['train']
testlist = data_options['valid']
gpus = data_options['gpus']  # comma-separated device ids, e.g. 0,1,2,3
num_workers = int(data_options['num_workers'])
batch_size = int(net_options['batch'])

# ---- Runtime switches ----------------------------------------------------
use_cuda = True
seed = 22222
eps = 1e-5  # keeps the precision / recall / F-score divisions away from 0

# ---- Detection and matching thresholds -----------------------------------
conf_thresh = 0.25  # minimum confidence for a predicted box to be kept
nms_thresh = 0.4    # threshold handed to nms()
iou_thresh = 0.5    # IoU a prediction needs to count as matching a truth
# ---- Reproducibility and device selection --------------------------------
# `os` is not imported explicitly elsewhere in this script; it would otherwise
# only reach this module through `from utils import *`.
import os

torch.manual_seed(seed)
if use_cuda:
    # Restrict CUDA to the configured devices. CUDA reads this variable when
    # it initialises, so it must be set before any torch.cuda call.
    os.environ['CUDA_VISIBLE_DEVICES'] = gpus
    torch.cuda.manual_seed(seed)
# Build the network described by the cfg file and print its layers.
model = Darknet(cfgfile)
model.print_network()
init_width, init_height = model.width, model.height

# Loader worker processes and pinned host memory are only requested when
# batches will be moved to the GPU.
if use_cuda:
    kwargs = {'num_workers': num_workers, 'pin_memory': True}
else:
    kwargs = {}

# Validation images at the network's input size, converted to tensors,
# read in file order (no shuffling) with train=False passed to listDataset.
test_loader = torch.utils.data.DataLoader(
    dataset.listDataset(
        testlist,
        shape=(init_width, init_height),
        shuffle=False,
        transform=transforms.Compose([transforms.ToTensor()]),
        train=False,
    ),
    batch_size=batch_size,
    shuffle=False,
    **kwargs
)

# Wrap for GPU execution; DataParallel exposes the wrapped network as
# model.module.
if use_cuda:
    model = torch.nn.DataParallel(model).cuda()
def test():
    """Evaluate ``model`` on ``test_loader`` and log precision, recall, F-score.

    For every image, the network output is decoded by ``get_region_boxes``
    (using ``conf_thresh``) and filtered by ``nms`` (using ``nms_thresh``).
    A ground-truth box counts as correct when the predicted box with the
    highest IoU against it overlaps by more than ``iou_thresh`` and has the
    same class id.

    Note: matching is not one-to-one -- one predicted box may be the best
    match for several ground-truth boxes. This mirrors the original metric.

    Returns:
        tuple: ``(precision, recall, fscore)``, the same values that are
        logged via the ``logging`` helper. Callers that ignore the return
        value are unaffected.
    """
    def truths_length(truths):
        # Target rows appear to be zero-padded. The first row whose column 1
        # is 0 marks the end of the real boxes. When no such row exists,
        # every row is real: return the row count instead of falling off the
        # end (which used to return None and crash `total + num_gts`).
        for row in range(truths.size(0)):
            if truths[row][1] == 0:
                return row
        return truths.size(0)

    # When use_cuda is set the model is wrapped in DataParallel and the real
    # network lives at .module; otherwise use the model itself.
    net = getattr(model, 'module', model)

    model.eval()
    num_classes = net.num_classes
    anchors = net.anchors
    num_anchors = net.num_anchors

    # PyTorch 0.4 replaced Variable(volatile=True) with torch.no_grad(). From
    # then on the volatile flag is ignored and an autograd graph would be
    # recorded during evaluation, so prefer no_grad when it exists.
    no_grad = getattr(torch, 'no_grad', None)

    total = 0.0      # ground-truth boxes seen
    proposals = 0.0  # predicted boxes kept after NMS
    correct = 0.0    # ground-truth boxes matched by a same-class prediction

    for data, target in test_loader:
        if use_cuda:
            data = data.cuda()
        if no_grad is not None:
            with no_grad():
                output = model(data).data
        else:
            output = model(Variable(data, volatile=True)).data
        all_boxes = get_region_boxes(output, conf_thresh, num_classes, anchors, num_anchors)

        for img_idx in range(output.size(0)):
            boxes = nms(all_boxes[img_idx], nms_thresh)
            truths = target[img_idx].view(-1, 5)
            num_gts = truths_length(truths)
            total += num_gts

            # Index 4 of a predicted box is the confidence checked here.
            proposals += sum(1 for box in boxes if box[4] > conf_thresh)

            for gt_idx in range(num_gts):
                gt = truths[gt_idx]
                # Truth rows hold the class id in column 0 and the box
                # (centre format, hence x1y1x2y2=False) in columns 1-4.
                # Rearrange them into the prediction layout: coords first,
                # confidences fixed at 1.0, and the class id at index 6.
                box_gt = [gt[1], gt[2], gt[3], gt[4], 1.0, 1.0, gt[0]]
                best_iou = 0
                best_j = -1
                for j, box in enumerate(boxes):
                    iou = bbox_iou(box_gt, box, x1y1x2y2=False)
                    if iou > best_iou:
                        best_j = j
                        best_iou = iou
                # With no predictions best_iou stays 0, so boxes[-1] is never
                # consulted thanks to short-circuiting.
                if best_iou > iou_thresh and boxes[best_j][6] == box_gt[6]:
                    correct += 1

    precision = 1.0 * correct / (proposals + eps)
    recall = 1.0 * correct / (total + eps)
    fscore = 2.0 * precision * recall / (precision + recall + eps)
    logging("precision: %f, recall: %f, fscore: %f" % (precision, recall, fscore))
    return precision, recall, fscore
# Evaluate each weight file named on the command line (every argument after
# datacfg and cfgfile), reusing the same network and test loader.
for weightfile in sys.argv[3:]:
    # DataParallel (applied when use_cuda is set) keeps the real network at
    # .module; fall back to the bare model when it is not wrapped.
    getattr(model, 'module', model).load_weights(weightfile)
    logging('evaluating ... %s' % (weightfile))
    test()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。