代码拉取完成,页面将自动刷新
# coding=utf-8
"""
@header mypredict.py
@abstract Real-time monocular depth prediction from a camera feed using a PackNet-SfM checkpoint
@MyBlog: http://www.kuture.com.cn
@author Created by Kuture on 2021/8/9
@version 1.0.0 2021/8/9 Creation
Copyright © 2021年 Mr.Li All rights reserved
"""
import argparse
import cv2
import numpy as np
import torch
from PIL import Image
from packnet_sfm.models.model_wrapper import ModelWrapper
from packnet_sfm.datasets.augmentations import resize_image, to_tensor
from packnet_sfm.utils.horovod import hvd_init, rank
from packnet_sfm.utils.config import parse_test_file
from packnet_sfm.utils.depth import viz_inv_depth
class DepthPredictProcessor(object):
    """Run a PackNet-SfM depth model on RGB frames or on a live video source."""

    def __init__(self, model_file, img_half=True, image_shape=(160, 320)):
        """Load a PackNet-SfM checkpoint and prepare it for inference.

        Args:
            model_file: Path to a checkpoint readable by ``parse_test_file``.
            img_half: If True and CUDA is available, the model and its inputs
                are cast to float16. On CPU no cast is applied.
            image_shape: (height, width) each frame is resized to before
                inference. Other shapes tried with this model: 192x640,
                192x160, 160x160.
        """
        # Initialise horovod; rank() below selects the CUDA device index.
        hvd_init()
        config, state_dict = parse_test_file(model_file)
        model_wrapper = ModelWrapper(config, load_datasets=False)
        model_wrapper.load_state_dict(state_dict)
        self.dtype = torch.float16 if img_half else None
        # Resolve the device once here instead of on every frame.
        self.use_cuda = torch.cuda.is_available()
        self.device = 'cuda:{}'.format(rank()) if self.use_cuda else None
        if self.use_cuda:
            model_wrapper = model_wrapper.to(self.device, dtype=self.dtype)
        model_wrapper.eval()
        self.model_wrapper = model_wrapper
        # The checkpoint config also carries a shape
        # (config.datasets.augmentation.image_shape); an explicit one is used here.
        self.image_shape = tuple(image_shape)

    def predict(self, image, with_rgb=False):
        """Predict inverse depth for one RGB frame and return a colourised map.

        Args:
            image: RGB image array accepted by ``PIL.Image.fromarray``
                (presumably HxWx3 uint8 -- camera_display passes an RGB frame).
            with_rgb: If True, stack the resized input frame above the depth
                map in the returned image.

        Returns:
            uint8 colour image in BGR channel order, ready for ``cv2.imshow``.
        """
        image = Image.fromarray(image)
        # Resize to the model input shape and convert to a 1xCxHxW tensor.
        image = resize_image(image, self.image_shape)
        image = to_tensor(image).unsqueeze(0)
        if self.use_cuda:
            image = image.to(self.device, dtype=self.dtype)
        # Inference only: disable autograd to save memory and time per frame.
        with torch.no_grad():
            pred_inv_depth = self.model_wrapper.depth(image)['inv_depths'][0]
        # viz_inv_depth presumably returns colours in [0, 1]; scale to 0-255.
        result = viz_inv_depth(pred_inv_depth[0]) * 255
        if with_rgb:
            # Stack the resized input frame vertically above the depth map.
            rgb = image[0].permute(1, 2, 0).cpu().numpy() * 255
            result = np.concatenate([rgb, result], 0)
        # Reverse the channel axis (RGB -> BGR) for OpenCV display.
        return np.uint8(result[:, :, ::-1])

    def camera_display(self, source=0):
        """Show live depth predictions for a video source until 'q' is pressed.

        Args:
            source: Anything ``cv2.VideoCapture`` accepts -- a camera index
                (default 0) or a video file path.

        Raises:
            OSError: if the video source cannot be opened.
        """
        cap = cv2.VideoCapture(source)
        if not cap.isOpened():
            raise OSError('Cannot open video source: {}'.format(source))
        # Display the depth map at the source's native resolution.
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        try:
            while True:
                ret_val, frame = cap.read()
                if not ret_val:
                    # End of video or camera lost: stop instead of spinning forever.
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                res_img = self.predict(frame)
                cv2.imshow('', cv2.resize(res_img, (width, height)))
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Always release the capture device and close the display window.
            cap.release()
            cv2.destroyAllWindows()
def parse_args(argv=None):
    """Parse command-line options for the camera depth demo.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model_file`` (str) and ``full_precision`` (bool).
    """
    parser = argparse.ArgumentParser(
        description='Run PackNet-SfM depth prediction on a live camera feed.')
    parser.add_argument('--model-file',
                        default='./Data/PackNet01_MR_velsup_CStoK.ckpt',
                        help='Path to the PackNet-SfM checkpoint (.ckpt).')
    parser.add_argument('--full-precision', action='store_true',
                        help='Run the model in float32 instead of float16.')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    dep_pred = DepthPredictProcessor(args.model_file,
                                     img_half=not args.full_precision)
    dep_pred.camera_display()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。