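# 代码拉取完成,页面将自动刷新  (web-page residue: "Code pull complete; the page will refresh automatically." Not part of the program.)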
import os
import sys
import argparse
import numpy as np
import time
import cv2
import glob
import onnxruntime
import torch
import torch.nn.functional as F
import math
# Put the directory one level above this script on the import path (only once),
# so project-local packages can be imported when the file is run directly.
prj_path = os.path.join(os.path.dirname(__file__), os.pardir)
if sys.path.count(prj_path) == 0:
    sys.path.append(prj_path)
def get_frames(video_name, rtsp_url=None):
    """Yield frames from a camera stream, a video file or an image folder.

    The source is chosen from ``video_name``:

    * falsy (``None`` / ``""``): an RTSP camera stream, ``rtsp_url`` or the
      built-in default camera URL. The first 5 frames are read and discarded
      as warm-up.
    * a path ending in ``avi`` or ``mp4``: that video file, read with OpenCV.
    * anything else: a directory whose ``img/`` sub-folder holds ``*.jp*``
      images, yielded in sorted filename order. Files that OpenCV cannot
      decode are skipped.

    Capture devices are released when the stream ends or the generator is
    closed.

    Args:
        video_name: Video file path, image-sequence directory, or a falsy
            value to read from the RTSP camera.
        rtsp_url: Optional RTSP URL overriding the default camera address.

    Yields:
        The frames as returned by ``cv2.VideoCapture.read`` / ``cv2.imread``.
    """

    def _read_capture(cap, warmup=0):
        # Read until the capture reports failure, always releasing it afterwards.
        try:
            for _ in range(warmup):
                cap.read()
            while True:
                ok, frame = cap.read()
                if not ok:
                    break
                yield frame
        finally:
            cap.release()

    if not video_name:
        if rtsp_url is None:
            rtsp_url = "rtsp://%s:%s@%s:554/cam/realmonitor?channel=1&subtype=1" % (
                "admin", "123456", "192.168.1.108")
        yield from _read_capture(cv2.VideoCapture(rtsp_url), warmup=5)
    elif video_name.endswith(("avi", "mp4")):
        yield from _read_capture(cv2.VideoCapture(video_name))
    else:
        # `glob` is the module here, so the function is `glob.glob`.
        pattern = os.path.join(video_name, 'img', '*.jp*')
        for img_path in sorted(glob.glob(pattern)):
            frame = cv2.imread(img_path)
            if frame is None:  # unreadable / corrupt file: skip instead of yielding None
                continue
            yield frame
class Preprocessor_wo_mask:
    """Turn an HWC image array into a normalised NCHW float tensor.

    Pixel values are divided by 255 and then standardised per channel with the
    ImageNet mean/std. The statistics are applied to the array's channels in
    their existing order.
    """

    def __init__(self):
        stat_shape = (1, 3, 1, 1)  # broadcast over (batch, channel, H, W)
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(stat_shape)
        self.std = torch.tensor([0.229, 0.224, 0.225]).view(stat_shape)

    def process(self, img_arr: np.ndarray):
        """Return ``img_arr`` (H, W, 3) as a contiguous (1, 3, H, W) normalised tensor."""
        chw = torch.tensor(img_arr).float().permute((2, 0, 1))
        nchw = chw.unsqueeze(dim=0)
        unit_range = nchw / 255.0
        standardised = (unit_range - self.mean) / self.std
        return standardised.contiguous()
class MFTrackerORT:
    """Single-object tracker whose network runs in ONNX Runtime.

    Call :meth:`track_init` once with the first frame and the initial box,
    then :meth:`track` on every following frame. Boxes are ``[x, y, w, h]``
    in image pixels, with ``(x, y)`` the top-left corner.
    """

    def __init__(self) -> None:
        self.debug = True  # track() draws the predicted box onto the input frame
        self.gpu_id = 0
        self.providers = ["CUDAExecutionProvider"]
        self.provider_options = [{"device_id": str(self.gpu_id)}]
        self.model_path = "AutoLabel.onnx"
        self.init_track_net()
        self.preprocessor = Preprocessor_wo_mask()
        # Multiplier applied each frame to the best score seen since the last
        # online-template refresh (1.0 = no decay).
        self.max_score_decay = 1.0
        self.search_factor = 4.5    # search crop side = sqrt(w * h) * search_factor
        self.search_size = 224      # search crop is resized to search_size x search_size
        self.template_factor = 2.0  # template crop side = sqrt(w * h) * template_factor
        self.template_size = 112    # template crop is resized to template_size x template_size
        self.update_interval = 200  # frames between online-template refreshes
        self.online_size = 1        # number of online-template slots

    def init_track_net(self):
        """Create the ONNX Runtime session for ``self.model_path``."""
        self.ort_session = onnxruntime.InferenceSession(
            self.model_path, providers=self.providers, provider_options=self.provider_options)

    def track_init(self, frame, target_pos=None, target_sz=None):
        """Initialise the tracker state from the first frame.

        Args:
            frame: First image, an (H, W, 3) ndarray.
            target_pos: Top-left corner ``[x, y]`` of the target box.
            target_sz: Size ``[w, h]`` of the target box.

        If the template cannot be built (e.g. a missing or degenerate box),
        an error message including the cause is printed and the process
        exits with status 1.
        """
        self.trace_list = []
        try:
            init_state = [target_pos[0], target_pos[1], target_sz[0], target_sz[1]]  # [x, y, w, h]
            z_patch_arr, _, _ = self.sample_target(
                frame, init_state, self.template_factor, output_sz=self.template_size)
            template = self.preprocessor.process(z_patch_arr)
        except Exception as err:
            # A failed initialisation must not look like success (exit status 1),
            # and the cause is printed instead of being swallowed.
            print(f"第一帧初始化异常! {err!r}")
            raise SystemExit(1) from err
        self.template = template
        self.online_template = template
        self.online_state = init_state
        self.online_image = frame
        self.max_pred_score = -1.0
        self.online_max_template = template
        self.online_forget_id = 0
        # save states
        self.state = init_state
        self.frame_id = 0
        print(f"第一帧初始化完毕!")

    def track(self, image, info: dict = None):
        """Locate the target in ``image`` and maintain the online template.

        Args:
            image: Current frame, an (H, W, 3) ndarray. When ``self.debug`` is
                True the predicted box is drawn onto it in place.
            info: Unused; kept for interface compatibility.

        Returns:
            dict with ``"target_bbox"`` (``[x, y, w, h]`` clipped to the image,
            at least 10 px wide/high) and ``"conf_score"`` (the model's score
            tensor).
        """
        H, W, _ = image.shape
        self.frame_id += 1
        # Crop a search region around the previous state.
        x_patch_arr, resize_factor, _ = self.sample_target(
            image, self.state, self.search_factor, output_sz=self.search_size)
        search = self.preprocessor.process(x_patch_arr)

        # compute ONNX Runtime output prediction
        ort_inputs = {
            'img_t': self.to_numpy(self.template),
            'img_ot': self.to_numpy(self.online_template),
            'img_search': self.to_numpy(search),
        }
        ort_outs = self.ort_session.run(None, ort_inputs)
        pred_boxes = torch.from_numpy(ort_outs[0])
        pred_score = torch.from_numpy(ort_outs[1])

        # Average all predicted boxes. Multiplying by search_size / resize_factor
        # (= crop side in image pixels) converts the crop-normalised boxes
        # (presumably in [0, 1]) into (cx, cy, w, h) pixels relative to the
        # crop's top-left corner; map_box_back then shifts them into the image.
        pred_box = (pred_boxes.mean(dim=0) * self.search_size / resize_factor).tolist()
        self.state = self.clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10)

        # Online template: remember the best-scoring crop seen so far...
        self.max_pred_score = self.max_pred_score * self.max_score_decay
        if pred_score > 0.5 and pred_score > self.max_pred_score:
            z_patch_arr, _, _ = self.sample_target(
                image, self.state, self.template_factor, output_sz=self.template_size)
            self.online_max_template = self.preprocessor.process(z_patch_arr)
            self.max_pred_score = pred_score
        # ...and every update_interval frames promote it to the online template.
        if self.frame_id % self.update_interval == 0:
            if self.online_size == 1:
                self.online_template = self.online_max_template
            else:
                # NOTE(review): online_template starts with a single slot, so for
                # online_size > 1 this slice is empty once online_forget_id >= 1.
                # Growing the tensor (torch.cat) would be needed first; left as-is
                # because online_size is fixed to 1 here.
                self.online_template[self.online_forget_id:self.online_forget_id + 1] = self.online_max_template
                self.online_forget_id = (self.online_forget_id + 1) % self.online_size
            self.max_pred_score = -1
            self.online_max_template = self.template

        if self.debug:
            x1, y1, w, h = self.state
            cv2.rectangle(image, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color=(0, 0, 255), thickness=1)
        return {"target_bbox": self.state, "conf_score": pred_score}

    def map_box_back(self, pred_box: list, resize_factor: float):
        """Map a ``(cx, cy, w, h)`` box given relative to the search crop's
        top-left corner (in image pixels) back to image ``[x, y, w, h]``.

        The crop is assumed centred on the current ``self.state`` with side
        ``search_size / resize_factor``.
        """
        cx_prev = self.state[0] + 0.5 * self.state[2]
        cy_prev = self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

    def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float):
        """Batched :meth:`map_box_back`: ``(N, 4)`` cx/cy/w/h -> ``(N, 4)`` x/y/w/h."""
        cx_prev = self.state[0] + 0.5 * self.state[2]
        cy_prev = self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box.unbind(-1)  # (N,4) --> (N,)
        half_side = 0.5 * self.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1)

    def to_numpy(self, tensor):
        """Return ``tensor`` as a CPU numpy array, detaching it if it tracks gradients."""
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    def sample_target(self, im, target_bb, search_area_factor, output_sz=None, mask=None):
        """Extract a square crop centred on ``target_bb``.

        The crop side is ``ceil(sqrt(w * h) * search_area_factor)``; any part
        outside the image is zero-padded.

        Args:
            im: Image as an (H, W, C) ndarray.
            target_bb: Box ``[x, y, w, h]`` as a list/tuple or any object with
                ``.tolist()`` (e.g. a tensor or ndarray).
            search_area_factor: Ratio of crop side to ``sqrt(w * h)``.
            output_sz: If given, the crop (and masks) are resized to
                ``output_sz x output_sz``; otherwise no resizing is done.
            mask: Optional mask cropped alongside the image; presumably a 2-D
                torch tensor, since ``F.pad`` / ``F.interpolate`` are applied to it.

        Returns:
            ``(crop, resize_factor, att_mask)``, or
            ``(crop, resize_factor, att_mask, mask_crop)`` when ``mask`` is
            given. ``resize_factor`` is ``output_sz / crop_side`` (1.0 without
            resizing). ``att_mask`` is a bool array that is True on padded
            pixels; after resizing it is also True where interpolation blends
            in padding.

        Raises:
            ValueError: if the computed crop side is smaller than 1 pixel.
        """
        if isinstance(target_bb, (list, tuple)):
            x, y, w, h = target_bb
        else:
            x, y, w, h = target_bb.tolist()

        # Crop image
        crop_sz = math.ceil(math.sqrt(w * h) * search_area_factor)
        if crop_sz < 1:
            raise ValueError('Too small bounding box.')
        x1 = int(round(x + 0.5 * w - crop_sz * 0.5))
        x2 = int(x1 + crop_sz)
        y1 = int(round(y + 0.5 * h - crop_sz * 0.5))
        y2 = int(y1 + crop_sz)
        # NOTE(review): the "+ 1" pads one extra row/column even when the crop
        # ends exactly at the image border. It is kept as-is because the network
        # was presumably trained with this cropping.
        x1_pad = int(max(0, -x1))
        x2_pad = int(max(x2 - im.shape[1] + 1, 0))
        y1_pad = int(max(0, -y1))
        y2_pad = int(max(y2 - im.shape[0] + 1, 0))

        # Crop target
        im_crop = im[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad, :]
        if mask is not None:
            mask_crop = mask[y1 + y1_pad:y2 - y2_pad, x1 + x1_pad:x2 - x2_pad]

        # Pad
        im_crop_padded = cv2.copyMakeBorder(im_crop, y1_pad, y2_pad, x1_pad, x2_pad, cv2.BORDER_CONSTANT)
        # Attention mask: 1 on padding, 0 on real image content.
        H, W, _ = im_crop_padded.shape
        att_mask = np.ones((H, W))
        end_x = None if x2_pad == 0 else -x2_pad
        end_y = None if y2_pad == 0 else -y2_pad
        att_mask[y1_pad:end_y, x1_pad:end_x] = 0
        if mask is not None:
            mask_crop_padded = F.pad(mask_crop, pad=(x1_pad, x2_pad, y1_pad, y2_pad), mode='constant', value=0)

        if output_sz is not None:
            resize_factor = output_sz / crop_sz
            im_crop_padded = cv2.resize(im_crop_padded, (output_sz, output_sz))
            att_mask = cv2.resize(att_mask, (output_sz, output_sz)).astype(np.bool_)
            if mask is None:
                return im_crop_padded, resize_factor, att_mask
            mask_crop_padded = F.interpolate(
                mask_crop_padded[None, None], (output_sz, output_sz), mode='bilinear', align_corners=False)[0, 0]
            return im_crop_padded, resize_factor, att_mask, mask_crop_padded

        # Same tuple order as the resized path: (crop, resize_factor, att_mask[, mask]).
        if mask is None:
            return im_crop_padded, 1.0, att_mask.astype(np.bool_)
        return im_crop_padded, 1.0, att_mask.astype(np.bool_), mask_crop_padded

    def clip_box(self, box: list, H, W, margin=0):
        """Clip ``[x, y, w, h]`` to an ``H x W`` image.

        The top-left corner is kept at least ``margin`` pixels from the right
        and bottom edges, and the result is at least ``margin`` wide and high.
        """
        x1, y1, w, h = box
        x2, y2 = x1 + w, y1 + h
        x1 = min(max(0, x1), W - margin)
        x2 = min(max(margin, x2), W)
        y1 = min(max(0, y1), H - margin)
        y2 = min(max(margin, y2), H)
        w = max(margin, x2 - x1)
        h = max(margin, y2 - y1)
        return [x1, y1, w, h]
def _bbox_to_yolo(bbox, img_width, img_height):
    """Convert a pixel box ``[x1, y1, w, h]`` (top-left corner + size) into the
    normalised YOLO form ``[x_center, y_center, w, h]``."""
    x1, y1, w, h = bbox
    x_center = (x1 + w / 2.0) / img_width
    y_center = (y1 + h / 2.0) / img_height
    return [x_center, y_center, w / img_width, h / img_height]


def run(video_path, output_dir, file_base_name, classes_name="demo"):
    """Interactively track one object through a video and write YOLO labels.

    Keys (read with ``cv2.waitKey``):

    * ``i``: select the target with an ROI selector and initialise the tracker.
    * ``a``: automatic mode; each following frame is tracked and saved.
    * ``s``: pause; block on every frame until a key is pressed.
    * ``p``: skip the current frame.

    In automatic mode, each tracked frame is written (before the debug box is
    drawn) to ``<output_dir>/<frame_id>.jpg``. Its label is written to
    ``<output_dir>/<frame_id>.txt`` as one line ``0 cx cy w h`` with
    coordinates normalised to [0, 1].

    Args:
        video_path: Source passed to ``get_frames``; also opened once to read
            the total frame count shown on screen.
        output_dir: Destination for images, labels and ``classes.txt``
            (created if missing).
        file_base_name: Currently unused; kept for interface compatibility.
        classes_name: Text written to ``<output_dir>/classes.txt``.
    """
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "classes.txt"), "w") as f:
        f.write(classes_name)

    Tracker = MFTrackerORT()
    Tracker.video_name = video_path

    # Opened only to read the frame count, so release it straight away.
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()

    isStop = False       # True -> block on every frame until a key is pressed
    isAuto = False       # True -> track, save and label every frame
    initialized = False  # becomes True once the target was selected with 'i'
    try:
        for frame_id, frame in enumerate(get_frames(Tracker.video_name)):
            if frame_id == 0:
                # Show the first frame and block on the single waitKey below
                # (a second waitKey here would swallow the user's first key).
                isStop = True
                cv2.imshow("Frame", frame)
            key = cv2.waitKey(0) if isStop else cv2.waitKey(1)

            if key == ord('a'):  # Start continuous tracking
                isAuto = True
                isStop = False
                if not initialized:
                    print("Press 'i' to select the target before tracking starts.")
            if key == ord('s'):  # Pause tracking
                isAuto = False
                isStop = True
            if key == ord("p"):  # Skip this frame
                if frame_id != 0:
                    cv2.putText(frame, f"Passed [{frame_id}]", (20, 30), cv2.FONT_HERSHEY_SIMPLEX,
                                1, (0, 255, 0), 2)
                    cv2.imshow("Frame", frame)
                print(f"pass {frame_id}")
                continue
            if key == ord("i"):  # Initialize the bounding box
                cv2.destroyAllWindows()
                x, y, w, h = cv2.selectROI("Init", frame, fromCenter=False)
                Tracker.track_init(frame, [x, y], [w, h])
                Tracker.track(frame)
                initialized = True
                cv2.imshow('Tracking', frame)
                continue

            # Track only once a target exists. This includes the frame on which
            # 'a' was pressed, which the old `key != ord('a')` check silently dropped.
            if isAuto and initialized:
                # Save the clean frame before track() draws the debug box onto it.
                cv2.imwrite(os.path.join(output_dir, f"{frame_id}.jpg"), frame)
                state = Tracker.track(frame)
                img_height, img_width, _ = frame.shape
                yolo_bbox = _bbox_to_yolo(state["target_bbox"], img_width, img_height)
                # 'w' (not 'a') so re-running into the same output_dir does not duplicate labels.
                with open(os.path.join(output_dir, f"{frame_id}.txt"), 'w') as f:
                    f.write(str(0) + ' ' + ' '.join(map(str, yolo_bbox)) + '\n')
                cv2.putText(frame, f"Tracking [{frame_id} | {total_frames}]", (20, 30), cv2.FONT_HERSHEY_SIMPLEX,
                            1, (255, 0, 0), 1)
                cv2.imshow('Tracking', frame)
    finally:
        cv2.destroyAllWindows()
if __name__ == '__main__':
    # The defaults reproduce the original hard-coded run; override them on the command line.
    parser = argparse.ArgumentParser(
        description="Interactively track an object in a video and write YOLO labels.")
    parser.add_argument(
        "--video-path",
        default="/home/deepseavision/Documents/720p_tracking_dataset/track_failure_case/NORM0030_shift_to_building.mp4",
        help="input video file")
    parser.add_argument(
        "--output-dir",
        default="/home/deepseavision/Documents/720p_tracking_dataset/track_failure_case/NORM0030_shift_to_building",
        help="directory for saved frames, labels and classes.txt")
    parser.add_argument("--file-base-name", default="building_1",
                        help="passed through to run() (currently unused there)")
    parser.add_argument("--classes-name", default="demo",
                        help="text written to classes.txt")
    args = parser.parse_args()
    run(video_path=args.video_path,
        output_dir=args.output_dir,
        file_base_name=args.file_base_name,
        classes_name=args.classes_name)
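# 此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
# 如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。
# (Web-page moderation notice captured with the file, not part of the program: "Content here may be
#  unsuitable for display and is hidden. If you confirm it contains no inappropriate, advertising,
#  violent, vulgar, infringing, pirated, false, worthless or illegal content, you may submit an appeal.")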