加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
detector.py 2.86 KB
一键复制 编辑 原始数据 按行查看 历史
import torch
import numpy as np
from models.experimental import attempt_load
from utils.datasets import letterbox
from utils.general import non_max_suppression, scale_coords
from utils.torch_utils import select_device
class Detector:
    """YOLOv5 object detector with two independent loading paths.

    * Default path (``old_init=False``): loads the ``yolov5<model_type>`` hub
      entry point from the local ``./weights/`` directory; use :meth:`predict`.
    * Legacy path (``old_init=True``): :meth:`__old_init__` loads ``yolov5m``
      and sets up its own pre/post-processing; use :meth:`detect`.

    Each path only sets the attributes its own inference method reads, so
    ``predict`` works only after the default path and ``detect`` only after
    the legacy path.
    """

    # Class names kept by ``detect``; every other detection is discarded.
    DETECT_LABELS = frozenset(
        ['person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck'])

    def __init__(self, img_size=640, conf=0.01, iou=0.65, model_type='x6', old_init=False):
        """Load the detector.

        Args:
            img_size: Inference size passed to the hub model by ``predict``.
            conf: Value assigned to the hub model's ``conf`` attribute.
            iou: Value assigned to the hub model's ``iou`` attribute.
            model_type: Suffix of the ``yolov5<model_type>`` hub entry point.
            old_init: When true, use the legacy loader instead. All other
                arguments are then ignored; the legacy path has fixed settings.
        """
        if old_init:
            self.__old_init__()
        else:
            self.model = torch.hub.load('./weights/', 'yolov5{}'.format(model_type), pretrained=True, source='local')
            self.model.conf = conf
            self.model.iou = iou
            # Optional class filter (disabled): self.model.classes = [0, 1, 2, 3, 5, 7, 8]
            self.img_size = img_size

    def predict(self, img):
        """Run the hub model on ``img``; return the rendered frame and its boxes.

        ``img`` is handed straight to the hub model (the original note showed
        it being obtained with ``cv2.imread(im_path)``).

        Returns:
            ``(frame, boxes)``: ``frame`` is image 0 of the results, taken
            after ``results.render()`` (presumably drawn on in place), and
            ``boxes`` is a list of ``[xmin, ymin, xmax, ymax, name, confidence]``
            rows for that same image.
        """
        results = self.model(img, size=self.img_size)
        results.render()
        # Use image 0 so the returned frame corresponds to ``xyxy[0]`` below,
        # including for batched input.
        output_image_frame = results.imgs[0]
        columns = ['xmin', 'ymin', 'xmax', 'ymax', 'name', 'confidence']
        boxes = results.pandas().xyxy[0][columns].values.tolist()
        return output_image_frame, boxes

    def __old_init__(self, model_type='m'):
        """Legacy loader: prepare ``yolov5<model_type>`` for use with ``detect``.

        Args:
            model_type: Suffix of the hub entry point. Defaults to ``'m'`` so
                the loaded model matches ``self.weights`` (``yolov5m.pt``).
        """
        self.img_size = 640
        self.threshold = 0.3
        self.stride = 1
        self.weights = './weights/yolov5{}.pt'.format(model_type)
        use_cuda = torch.cuda.is_available()
        self.device = select_device('0' if use_cuda else 'cpu')
        # fp16 only on GPU: many CPU kernels (e.g. conv2d) lack float16 support.
        self.use_half = use_cuda
        model = torch.hub.load('./weights/', 'yolov5{}'.format(model_type), pretrained=True, source='local')
        model.to(self.device).eval()
        if self.use_half:
            model.half()
        self.m = model
        # Unwrap a possible ``.module`` wrapper to reach the class-name list.
        self.names = model.module.names if hasattr(model, 'module') else model.names

    def preprocess(self, img):
        """Letterbox ``img`` and turn it into a batched tensor for ``self.m``.

        Assumes ``img`` is an HWC numpy image with 0-255 values (TODO confirm
        against callers).

        Returns:
            ``(img0, tensor)``: an untouched copy of the input (used later to
            rescale boxes) and a ``(1, C, H, W)`` tensor on ``self.device``,
            in fp16 on GPU and fp32 on CPU, scaled by 1/255.
        """
        img0 = img.copy()
        img = letterbox(img, new_shape=self.img_size)[0]
        # Reverse the channel order (presumably BGR -> RGB), then HWC -> CHW.
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img)
        img = torch.from_numpy(img).to(self.device)
        img = img.half() if self.use_half else img.float()
        img /= 255.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)
        return img0, img

    def detect(self, im):
        """Detect objects in ``im`` with the legacy model, keeping DETECT_LABELS.

        Returns:
            List of ``(x1, y1, x2, y2, label, conf)`` tuples with integer
            pixel coordinates in the space of ``im``; ``conf`` is the 0-d
            tensor element taken from the detection row.
        """
        im0, img = self.preprocess(im)
        pred = self.m(img, augment=False)[0]
        pred = pred.float()
        # Positional args per YOLOv5: confidence threshold, then IoU threshold.
        pred = non_max_suppression(pred, self.threshold, 0.4)
        boxes = []
        for det in pred:
            if det is None or not len(det):
                continue
            # Map boxes from the letterboxed tensor size back onto ``im``.
            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
            for *xyxy, conf, cls_id in det:
                lbl = self.names[int(cls_id)]
                if lbl not in self.DETECT_LABELS:
                    continue
                x1, y1, x2, y2 = (int(v) for v in xyxy[:4])
                boxes.append((x1, y1, x2, y2, lbl, conf))
        return boxes
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化