# [Web-page residue, not part of the program] "Code pull complete; the page will refresh automatically."
from __future__ import division
import time
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import cv2
from util import *
from darknet import Darknet
from preprocess import prep_image, inp_to_image
import pandas as pd
import random
import argparse
import pickle as pkl
import ft2
def get_test_input(input_dim, CUDA):
    """
    Build a single-image input batch from the sample picture "imgs/messi.jpg".

    The image is resized to (input_dim, input_dim), its channel axis is
    reversed (OpenCV loads BGR, so this yields RGB), channels are moved to
    the front, a leading batch axis is added and values are divided by 255.

    Args:
        input_dim: side length, in pixels, of the square image to produce.
        CUDA: when truthy, the result is moved to the GPU before returning.

    Returns:
        A float Variable with a leading batch axis of size 1.
    """
    sample = cv2.resize(cv2.imread("imgs/messi.jpg"), (input_dim, input_dim))
    # Reverse the channel axis, then reorder HWC -> CHW.
    channels_first = sample[:, :, ::-1].transpose((2, 0, 1))
    # Add the batch axis and scale pixel values by 1/255.
    batch = channels_first[np.newaxis, :, :, :] / 255.0
    tensor = Variable(torch.from_numpy(batch).float())
    return tensor.cuda() if CUDA else tensor
def prep_image(img, inp_dim):
    """
    Prepare an image for input to the neural network.

    The image is resized to (inp_dim, inp_dim) without preserving aspect
    ratio, its channel axis is reversed, channels are moved to the front,
    values are divided by 255 and a leading batch axis is added.

    Args:
        img: image array indexed as [row, column, channel].
        inp_dim: side length, in pixels, of the square network input.

    Returns:
        A 3-tuple ``(tensor, orig_im, dim)``: the float tensor with a batch
        axis of size 1, the untouched input image (same object as ``img``),
        and ``(width, height)`` of the input image.
    """
    height, width = img.shape[0], img.shape[1]
    resized = cv2.resize(img, (inp_dim, inp_dim))
    # The reversed view has negative strides; copy() makes it contiguous so
    # torch.from_numpy can wrap it.
    channels_first = resized[:, :, ::-1].transpose((2, 0, 1)).copy()
    tensor = torch.from_numpy(channels_first).float()
    tensor = tensor.div(255.0).unsqueeze(0)
    return tensor, img, (width, height)
# Cache of ft2 text renderers keyed by font path, so the font file is not
# re-opened for every box drawn on every frame.
_FONT_CACHE = {}


def _get_caption_font(font_path='msyh.ttf'):
    """Return the ft2 text renderer for ``font_path``, creating it on first use."""
    # NOTE(review): assumes an ft2.put_chinese_text object can be reused
    # across draw_text calls -- confirm against ft2.
    if font_path not in _FONT_CACHE:
        _FONT_CACHE[font_path] = ft2.put_chinese_text(font_path)
    return _FONT_CACHE[font_path]


def write(x, img):
    """
    Draw one detection (box outline plus caption) onto ``img``.

    Relies on the module-level ``classes`` (index -> class name) and
    ``colors`` (sequence of colours to pick from) being defined.

    Args:
        x: 1-D tensor describing one detection. Elements 1:3 and 3:5 are
           used as the two box corners, element -2 as the confidence and
           element -1 as the class index.
        img: image to draw on.

    Returns:
        ``img`` (the same object that was passed in).
    """
    # OpenCV expects plain Python ints for points; int() truncates toward
    # zero exactly like the tensor .int() cast used previously.
    c1 = tuple(int(v) for v in x[1:3])
    c2 = tuple(int(v) for v in x[3:5])
    cls = int(x[-1])
    # Read the confidence numerically rather than slicing the tensor's
    # string repr, which broke on values such as "tensor(1.)" or
    # scientific notation and leaked float noise into the caption.
    conf = float(x[-2])
    label = "{} {:.1f}%".format(classes[cls], conf * 100)
    color = random.choice(colors)
    cv2.rectangle(img, c1, c2, color, 1)
    # Text size is only used to place the caption vertically.
    t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
    ft = _get_caption_font()
    # NOTE(review): the return value of draw_text is discarded, as in the
    # original; this presumes it draws on ``img`` in place -- confirm.
    ft.draw_text(image=img, pos=(c1[0], c1[1] + t_size[1] - 7), text=label,
                 text_size=15, text_color=[255, 255, 255])
    return img
def arg_parse():
    """
    Parse command-line arguments for the detect module.

    Options:
        --confidence: object confidence threshold (float, default 0.25).
        --nms_thresh: NMS threshold (float, default 0.4).
        --reso: network input resolution, kept as a string (default "160").

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description='YOLO v3 Cam Demo')
    # type=float makes user-supplied values the same type as the defaults
    # and lets argparse reject non-numeric input with a usage message.
    parser.add_argument("--confidence", dest="confidence", type=float,
                        default=0.25,
                        help="Object Confidence to filter predictions")
    parser.add_argument("--nms_thresh", dest="nms_thresh", type=float,
                        default=0.4, help="NMS Threshhold")
    parser.add_argument("--reso", dest='reso', type=str, default="160",
                        help="Input resolution of the network. Increase to increase accuracy. Decrease to increase speed")
    return parser.parse_args()
if __name__ == '__main__':
    # Webcam demo: grab frames from camera 0, run the Darknet model on each,
    # draw the detections, show the result in a window and record it to
    # "cn_caption.avi". Press 'q' in the window to stop.
    cfgfile = "cfg/yolov3.cfg"
    weightsfile = "yolov3.weights"
    num_classes = 80

    args = arg_parse()
    confidence = float(args.confidence)
    nms_thesh = float(args.nms_thresh)
    CUDA = torch.cuda.is_available()

    model = Darknet(cfgfile)
    model.load_weights(weightsfile)
    model.net_info["height"] = args.reso
    inp_dim = int(model.net_info["height"])
    # Explicit checks instead of assert, which is stripped under `python -O`.
    if inp_dim % 32 != 0:
        raise ValueError("--reso must be a multiple of 32, got {}".format(inp_dim))
    if inp_dim <= 32:
        raise ValueError("--reso must be greater than 32, got {}".format(inp_dim))

    if CUDA:
        model.cuda()
    model.eval()

    # Loaded once up front; both are read as module-level globals by write().
    # Previously they were reloaded from disk on every frame with detections
    # and the palette file handle was never closed.
    classes = load_classes('data/coco.names')
    # NOTE: pickle can execute arbitrary code on load; only use a trusted
    # "pallete" file.
    with open("pallete", "rb") as palette_file:
        colors = pkl.load(palette_file)

    # Alternative source: an IP camera (currently unused).
    cameraIp = "192.168.0.105"
    username = "admin"
    password = ""
    #cap = cv2.VideoCapture("http://"+cameraIp+"/video.cgi?user=%s&pwd=%s&"%(username,password))
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        raise RuntimeError('Cannot capture source')

    video_path = "cn_caption.avi"
    fourcc = cv2.VideoWriter_fourcc(*'MJPG')
    # The writer's frame size must match the frames handed to it, so use the
    # capture's reported size and fall back to the former 640x480 if the
    # backend reports 0.
    frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 640
    frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) or 480
    out = cv2.VideoWriter(video_path, fourcc, float(18), (frame_w, frame_h), True)

    frames = 0
    start = time.time()
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # orig_im is the same array object as frame, so boxes drawn on
            # orig_im also appear in the frame written to the video.
            img, orig_im, _ = prep_image(frame, inp_dim)
            if CUDA:
                img = img.cuda()

            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                output = model(Variable(img), CUDA)
            output = write_results(output, confidence, num_classes, nms = True, nms_conf = nms_thesh)

            # write_results hands back an int instead of a tensor when there
            # is nothing to draw (presumably: no detections -- see util).
            if not isinstance(output, int):
                # Columns 1..4 are box coordinates in network-input pixels:
                # clamp to the input square, normalise to [0, 1], then scale
                # x by the frame width and y by the frame height.
                output[:,1:5] = torch.clamp(output[:,1:5], 0.0, float(inp_dim))/inp_dim
                output[:,[1,3]] *= frame.shape[1]
                output[:,[2,4]] *= frame.shape[0]
                for detection in output:
                    write(detection, orig_im)

            print(frame.shape)
            # Every frame is recorded, including those without detections
            # (these used to be dropped from the output video).
            out.write(frame)
            cv2.imshow("frame", orig_im)
            frames += 1
            print("FPS of the video is {:5.2f}".format( frames / (time.time() - start)))
            key = cv2.waitKey(1)
            if key & 0xFF == ord('q'):
                break
    finally:
        # Release the camera, finalise the video file and close the window
        # even if the loop exits through an exception.
        cap.release()
        out.release()
        cv2.destroyAllWindows()
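# [Web-page residue, not part of the program] "This content may be unsuitable for display and is hidden by the page. You can review and modify it using the relevant editing functions."
# [Web-page residue, not part of the program] "If you confirm the content contains no inappropriate language, pure advertising, violence, vulgarity or pornography, infringement, piracy, false or worthless content, or content violating national laws and regulations, you can click submit to appeal and we will handle it as soon as possible."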