加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
paddle_ocr_rknn.py 20.29 KB
一键复制 编辑 原始数据 按行查看 历史
yinghan 提交于 2022-09-27 23:15 . paddleocr2rknn
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583
import os
import sys
import cv2
import time
import onnx
import math
import copy
import onnxruntime
import numpy as np
import pyclipper
from shapely.geometry import Polygon
from rknn.api import RKNN
# PaddleOCR detection module: image preprocessing classes it needs
class NormalizeImage(object):
    """Normalize an image: multiply by ``scale``, subtract ``mean``, divide by ``std``.

    Args:
        scale: Multiplier applied to the pixel values before the mean is
            subtracted. Either a number or a string holding a simple
            arithmetic expression such as ``'1./255.'``. ``None`` selects
            ``1/255``.
        mean: Per-channel mean; ``None`` selects ``[0.485, 0.456, 0.406]``.
        std: Per-channel divisor; ``None`` selects ``[0.229, 0.224, 0.225]``.
        order: ``'chw'`` reshapes mean/std to ``(3, 1, 1)``; any other value
            reshapes them to ``(1, 1, 3)``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
        if isinstance(scale, str):
            # Config files carry the scale as text (e.g. '1./255.'); parse it
            # with a restricted evaluator instead of eval().
            scale = self._parse_scale(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]

        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype('float32')
        self.std = np.array(std).reshape(shape).astype('float32')

    @staticmethod
    def _parse_scale(expr):
        """Evaluate a numeric expression made only of numbers and + - * /.

        Raises:
            ValueError: if the expression contains anything else (names,
                calls, attribute access, ...).
        """
        import ast
        import operator

        binary_ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
        }
        unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

        def evaluate(node):
            kind = type(node)
            # Numeric literals parse as ast.Num on old interpreters and as
            # ast.Constant on newer ones; match by class name to cover both.
            if kind.__name__ in ('Constant', 'Num'):
                value = node.value if hasattr(node, 'value') else node.n
                if isinstance(value, (int, float)) and not isinstance(value, bool):
                    return value
            elif kind is ast.BinOp and type(node.op) in binary_ops:
                return binary_ops[type(node.op)](
                    evaluate(node.left), evaluate(node.right))
            elif kind is ast.UnaryOp and type(node.op) in unary_ops:
                return unary_ops[type(node.op)](evaluate(node.operand))
            raise ValueError(
                "unsupported scale expression: {!r}".format(expr))

        return evaluate(ast.parse(expr.strip(), mode='eval').body)

    def __call__(self, data):
        """Normalize ``data['image']`` in place in the dict and return the dict."""
        img = data['image']
        # Pillow is only needed when a PIL image is actually passed in.
        try:
            from PIL import Image
        except ImportError:
            Image = None
        if Image is not None and isinstance(img, Image.Image):
            img = np.array(img)

        assert isinstance(img,
                          np.ndarray), "invalid input 'img' in NormalizeImage"
        data['image'] = (
            img.astype('float32') * self.scale - self.mean) / self.std
        return data
class ToCHWImage(object):
    """Transpose ``data['image']`` with axes ``(2, 0, 1)`` (HWC -> CHW).

    Args:
        **kwargs: Accepted and ignored.
    """

    def __init__(self, **kwargs):
        pass

    def __call__(self, data):
        """Replace ``data['image']`` by its transposed array and return ``data``."""
        img = data['image']
        # Pillow is only needed when a PIL image is actually passed in, so a
        # missing Pillow install must not break plain ndarray input.
        try:
            from PIL import Image
        except ImportError:
            Image = None
        if Image is not None and isinstance(img, Image.Image):
            img = np.array(img)
        data['image'] = img.transpose((2, 0, 1))
        return data
class KeepKeys(object):
    """Project a data dict onto a fixed, ordered list of keys.

    Args:
        keep_keys: Keys whose values are extracted, in this order.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, keep_keys, **kwargs):
        self.keep_keys = keep_keys

    def __call__(self, data):
        """Return ``[data[k] for k in keep_keys]`` as a list."""
        return [data[name] for name in self.keep_keys]
class DetResizeForTest(object):
    """Resize an image so both sides are multiples of 32 for the detector.

    Keyword Args:
        limit_side_len: Required. Upper bound for the longer image side; a
            longer image is scaled down so its long side matches this value
            before rounding to a multiple of 32.
        limit_type: Stored on the instance (default ``'min'``) but not
            consulted by the resize logic in this class.
    """

    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
        self.resize_type = 0
        self.limit_side_len = kwargs['limit_side_len']
        self.limit_type = kwargs.get('limit_type', 'min')

    def __call__(self, data):
        """Resize ``data['image']`` and record the geometry in ``data['shape']``.

        Returns:
            ``data`` with ``'image'`` replaced by the resized image and
            ``'shape'`` set to ``[src_h, src_w, ratio_h, ratio_w]``, or
            ``None`` when the image cannot be resized (empty input).
        """
        img = data['image']
        src_h, src_w, _ = img.shape
        img, (ratio_h, ratio_w) = self.resize_image_type0(img)
        if img is None:
            # Signal "drop this sample" instead of storing a None image.
            return None
        data['image'] = img
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def resize_image_type0(self, img):
        """
        resize image to a size multiple of 32 which is required by the network
        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, (ratio_h, ratio_w); ``(None, (None, None))`` when the input
            has a zero-length side
        """
        limit_side_len = self.limit_side_len
        h, w, _ = img.shape
        if h <= 0 or w <= 0:
            return None, (None, None)

        # limit the max side
        longest = max(h, w)
        ratio = float(limit_side_len) / longest if longest > limit_side_len else 1.

        # Round each side to the nearest multiple of 32, but never below 32:
        # without the clamp a side shorter than 16 px would round to 0.
        resize_h = max(int(round(int(h * ratio) / 32) * 32), 32)
        resize_w = max(int(round(int(w * ratio) / 32) * 32), 32)

        try:
            img = cv2.resize(img, (resize_w, resize_h))
        except Exception:
            # Keep the diagnostic output, then let the real error propagate
            # rather than exiting the process with a success status.
            print(img.shape, resize_w, resize_h)
            raise

        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        return img, [ratio_h, ratio_w]
### 检测结果后处理过程(得到检测框)
class DBPostProcess(object):
    """
    The post process for Differentiable Binarization (DB).

    Args:
        thresh: Probability threshold used to binarize the prediction map.
        box_thresh: Minimum mean probability inside a box for it to be kept.
        max_candidates: Maximum number of contours examined per image.
        unclip_ratio: Expansion factor used by :meth:`unclip`.
        use_dilation: When true, the binary map is dilated with a 2x2 kernel
            of ones before contour extraction.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 thresh=0.3,
                 box_thresh=0.7,
                 max_candidates=1000,
                 unclip_ratio=2.0,
                 use_dilation=False,
                 **kwargs):
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.dilation_kernel = None if not use_dilation else np.array(
            [[1, 1], [1, 1]])

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        Extract scored quadrilateral boxes from one binarized map.

        pred: probability map of the same image, used to score each box
        _bitmap: single 2-D map with shape (H, W),
            whose values are binarized as {0, 1}
        dest_width, dest_height: size the box coordinates are rescaled to
        return: (boxes as an int16 array, list of scores)
        '''
        bitmap = _bitmap
        height, width = bitmap.shape

        outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
                                cv2.CHAIN_APPROX_SIMPLE)
        # findContours returns (image, contours, hierarchy) or
        # (contours, hierarchy) depending on the OpenCV version; the contours
        # are the second-to-last element in both layouts.
        contours = outs[-2]

        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        for index in range(num_contours):
            contour = contours[index]
            points, sside = self.get_mini_boxes(contour)
            if sside < self.min_size:
                continue
            points = np.array(points)
            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if self.box_thresh > score:
                continue

            box = self.unclip(points).reshape(-1, 1, 2)
            box, sside = self.get_mini_boxes(box)
            if sside < self.min_size + 2:
                continue
            box = np.array(box)

            # Rescale from bitmap coordinates to the destination size.
            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.astype(np.int16))
            scores.append(score)
        return np.array(boxes, dtype=np.int16), scores

    def unclip(self, box):
        """Grow ``box`` outward by ``area * unclip_ratio / perimeter``.

        Returns the offset result from pyclipper wrapped in an ndarray.
        """
        unclip_ratio = self.unclip_ratio
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = np.array(offset.Execute(distance))
        return expanded

    def get_mini_boxes(self, contour):
        """Return the min-area rectangle of ``contour`` and its shorter side.

        The four corners are sorted by x; within the two left-most points the
        one with the smaller y comes first, and likewise for the two
        right-most. Output order is: left/small-y, right/small-y,
        right/large-y, left/large-y.
        """
        bounding_box = cv2.minAreaRect(contour)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        if points[1][1] > points[0][1]:
            index_1, index_4 = 0, 1
        else:
            index_1, index_4 = 1, 0
        if points[3][1] > points[2][1]:
            index_2, index_3 = 2, 3
        else:
            index_2, index_3 = 3, 2

        box = [
            points[index_1], points[index_2], points[index_3], points[index_4]
        ]
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        """Mean of ``bitmap`` inside polygon ``_box``.

        Only the polygon's axis-aligned bounding rectangle (clipped to the
        bitmap) is rasterized, which keeps the mask small.
        """
        h, w = bitmap.shape[:2]
        box = _box.copy()
        # Plain int() replaces the removed ``np.int`` alias.
        xmin = int(np.clip(np.floor(box[:, 0].min()), 0, w - 1))
        xmax = int(np.clip(np.ceil(box[:, 0].max()), 0, w - 1))
        ymin = int(np.clip(np.floor(box[:, 1].min()), 0, h - 1))
        ymax = int(np.clip(np.ceil(box[:, 1].max()), 0, h - 1))

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        # Shift the polygon into the cropped rectangle's coordinates.
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def __call__(self, outs_dict, shape_list):
        """Turn a batch of prediction maps into boxes.

        Args:
            outs_dict: 4-D array; channel 0 of axis 1 is used as the
                probability map of each batch item.
            shape_list: One row per batch item that unpacks to
                ``(src_h, src_w, ratio_h, ratio_w)``.

        Returns:
            A list with one ``{'points': boxes}`` dict per batch item.
        """
        pred = outs_dict
        pred = pred[:, 0, :, :]
        segmentation = pred > self.thresh

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
            if self.dilation_kernel is not None:
                mask = cv2.dilate(
                    np.array(segmentation[batch_index]).astype(np.uint8),
                    self.dilation_kernel)
            else:
                mask = segmentation[batch_index]
            boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
                                                   src_w, src_h)
            boxes_batch.append({'points': boxes})
        return boxes_batch
## 根据推理结果解码识别结果
class process_pred(object):
    """Decode recognition-model output into ``(text, confidence)`` pairs.

    Args:
        character_dict_path: Path of a UTF-8 text file; the characters of all
            its lines (line endings removed) form the alphabet, in file order.
        character_type: Accepted but not used by this class.
        use_space_char: When true, a space is appended to the alphabet.

    Index 0 of the resulting alphabet is the special ``'blank'`` entry.
    """

    def __init__(self, character_dict_path=None, character_type='ch', use_space_char=False):
        with open(character_dict_path, 'rb') as fin:
            pieces = [raw.decode('utf-8').strip('\r\n') for raw in fin]
        if use_space_char:
            pieces.append(' ')
        self.character_str = ''.join(pieces)

        dict_character = self.add_special_char(list(self.character_str))
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character

    def add_special_char(self, dict_character):
        """Return the alphabet with ``'blank'`` prepended at index 0."""
        return ['blank'] + dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index sequences to text.

        Args:
            text_index: One sequence of alphabet indices per batch item.
            text_prob: Optional per-position probabilities, same layout as
                ``text_index``; when omitted every kept position counts as 1.
            is_remove_duplicate: Skip a position whose index equals the
                index of the position immediately before it.

        Returns:
            A list of ``(text, mean_confidence)`` tuples. The confidence is
            0 when no character was kept.
        """
        result_list = []
        ignored_tokens = [0]  # index 0 is the blank entry
        for batch_idx, indices in enumerate(text_index):
            char_list = []
            conf_list = []
            for idx, token in enumerate(indices):
                if token in ignored_tokens:
                    continue
                if is_remove_duplicate and idx > 0 and indices[idx - 1] == token:
                    continue
                char_list.append(self.character[int(token)])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            if not conf_list:
                # np.mean([]) would yield NaN and emit a RuntimeWarning.
                conf_list = [0]
            result_list.append((''.join(char_list), np.mean(conf_list)))
        return result_list

    def __call__(self, preds, label=None):
        """Arg-max decode ``preds`` along axis 2, collapsing repeated indices.

        Returns the decoded predictions, or ``(predictions, decoded_label)``
        when ``label`` is given.
        """
        if not isinstance(preds, np.ndarray):
            preds = np.array(preds)
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        if label is None:
            return text
        label = self.decode(label)
        return text, label
def get_rknn(model_path):
    """Load an ``.rknn`` model file and initialise its runtime.

    Args:
        model_path: Path of the RKNN model file.

    Returns:
        The ``RKNN`` object, ready for ``inference``. The caller is
        responsible for calling ``release()`` on it.

    Terminates the process with exit status -1 when loading the model or
    initialising the runtime returns a non-zero code.
    """
    rknn = RKNN()

    # load RKNN model
    print('--> Load RKNN model: {}'.format(model_path))
    ret = rknn.load_rknn(path=model_path)
    if ret != 0:
        print("load RKNN model: {} failed.".format(model_path))
        # sys.exit, not the site-injected exit(), which is meant for the
        # interactive shell and is not guaranteed to exist.
        sys.exit(-1)
    print('--> done')

    # init RKNN Runtime
    print('--> Init RKNN Runtime.')
    ret = rknn.init_runtime()
    if ret != 0:
        print("Init RKNN Runtime failed.")
        sys.exit(-1)
    print('--> done')
    return rknn
class det_rec_functions(object):
    """Text detection and recognition on one image using two RKNN models.

    Args:
        image: Image array; a copy is stored on the instance.
        use_large: Accepted but not used by this class.

    Constructing an instance loads both RKNN models (see ``get_rknn``); the
    sessions are exposed as ``onet_det_session`` / ``onet_rec_session`` and
    must be released by the caller.
    """

    # (channels, height, width) every recognition crop is resized to.
    _REC_INPUT_SHAPE = (3, 32, 377)

    def __init__(self, image, use_large=False):
        self.img = image.copy()
        self.det_rknn_file = r'./rknn_weights/det_time_sim.rknn'
        self.small_rec_rknn_file = r'./rknn_weights/rec_time_sim.rknn'
        self.onet_det_session = get_rknn(self.det_rknn_file)
        self.onet_rec_session = get_rknn(self.small_rec_rknn_file)
        self.infer_before_process_op, self.det_re_process_op = self.get_process()
        self.postprocess_op = process_pred('./ppocr_keys_v1.txt', 'ch', True)

    # ---- image preprocessing ------------------------------------------------
    def transform(self, data, ops=None):
        """Apply ``ops`` to ``data`` in order; stop and return ``None`` as
        soon as an operator returns ``None``."""
        if ops is None:
            ops = []
        for op in ops:
            data = op(data)
            if data is None:
                return None
        return data

    def create_operators(self, op_param_list, global_config=None):
        """
        create operators based on the config
        Args:
            op_param_list(list): a list of single-entry dicts mapping an
                operator class name to its keyword arguments (or None)
            global_config(dict): optional extra keyword arguments merged into
                every operator's parameters
        Returns:
            list of instantiated operators, in config order
        Raises:
            NameError: if an operator name is not defined in this module
        """
        assert isinstance(op_param_list, list), ('operator config should be a list')
        ops = []
        for operator in op_param_list:
            assert isinstance(operator,
                              dict) and len(operator) == 1, "yaml format error"
            op_name = list(operator)[0]
            # Copy so that merging global_config never mutates the caller's
            # configuration dict.
            param = {} if operator[op_name] is None else dict(operator[op_name])
            if global_config is not None:
                param.update(global_config)
            # Resolve the class by name from module globals instead of eval().
            op_cls = globals().get(op_name)
            if op_cls is None:
                raise NameError("name '{}' is not defined".format(op_name))
            ops.append(op_cls(**param))
        return ops

    # ---- detection-box post-processing --------------------------------------
    def order_points_clockwise(self, pts):
        """
        reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py

        Order four points as [tl, tr, br, bl] and return them as float32.
        """
        # sort the points based on their x-coordinates
        xSorted = pts[np.argsort(pts[:, 0]), :]

        # grab the left-most and right-most points from the sorted
        # x-coordinate points
        leftMost = xSorted[:2, :]
        rightMost = xSorted[2:, :]

        # now, sort the left-most coordinates according to their
        # y-coordinates so we can grab the top-left and bottom-left
        # points, respectively
        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
        (tl, bl) = leftMost

        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
        (tr, br) = rightMost

        rect = np.array([tl, tr, br, bl], dtype="float32")
        return rect

    def clip_det_res(self, points, img_height, img_width):
        """Clamp every point into ``[0, width-1] x [0, height-1]`` in place
        and return ``points``."""
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order and clip each box, dropping boxes whose width or height
        (truncated to int) is 3 px or less. Returns an ndarray of boxes."""
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    # ---- pipeline definition ------------------------------------------------
    def get_process(self):
        """Build the detector's preprocessing operators and its
        ``DBPostProcess`` instance; returns ``(operators, postprocess)``."""
        det_db_thresh = 0.3
        det_db_box_thresh = 0.5
        max_candidates = 2000
        unclip_ratio = 1.6
        use_dilation = True

        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': 2500,
                'limit_type': 'max'
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]

        infer_before_process_op = self.create_operators(pre_process_list)
        det_re_process_op = DBPostProcess(det_db_thresh, det_db_box_thresh, max_candidates, unclip_ratio, use_dilation)
        return infer_before_process_op, det_re_process_op

    def sorted_boxes(self, dt_boxes):
        """
        Sort text boxes in order from top to bottom, left to right
        args:
            dt_boxes(array): detected text boxes with shape [N, 4, 2]
        return:
            list of N boxes, each with shape [4, 2]
        """
        num_boxes = dt_boxes.shape[0]
        _boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))

        # Boxes whose first corners are within 10 px vertically are treated
        # as one text line and ordered by x. Each box is bubbled backwards
        # (insertion-sort style) so that lines with three or more boxes end
        # up fully ordered; a single adjacent-swap pass does not guarantee
        # that.
        for i in range(num_boxes - 1):
            for j in range(i, -1, -1):
                same_line = abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
                if same_line and _boxes[j + 1][0][0] < _boxes[j][0][0]:
                    _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
                else:
                    break
        return _boxes

    # ---- recognition input preprocessing ------------------------------------
    def resize_norm_img(self, img, max_wh_ratio):
        """Resize a crop to the fixed recognition input size and normalize it.

        The crop is resized to ``_REC_INPUT_SHAPE`` regardless of its aspect
        ratio (``max_wh_ratio`` does not influence the result), converted to
        CHW float32, and mapped via ``(x / 255 - 0.5) / 0.5``.
        """
        imgC, imgH, imgW = self._REC_INPUT_SHAPE
        assert imgC == img.shape[2]

        resized_image = cv2.resize(img, (imgW, imgH)).astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5

        # Copy into a freshly allocated buffer of the model's input shape.
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:imgW] = resized_image
        return padding_im

    # ---- detection ----------------------------------------------------------
    def get_boxes(self):
        """Run the detection model on the stored image.

        Returns the filtered text boxes sorted top-to-bottom, left-to-right.
        Side effect: writes the model input to ``det_time_in0.npy``.
        """
        img_ori = self.img
        data_part = {'image': img_ori.copy()}
        data_part = self.transform(data_part, self.infer_before_process_op)
        img_part, shape_part_list = data_part
        img_part = np.expand_dims(img_part, axis=0)
        shape_part_list = np.expand_dims(shape_part_list, axis=0)

        # The preprocessing produced NCHW; the model is fed NHWC.
        img_part = np.transpose(img_part, (0, 2, 3, 1))
        np.save('det_time_in0.npy', img_part)
        outs_part = self.onet_det_session.inference(inputs=[img_part])

        post_res_part = self.det_re_process_op(outs_part[0], shape_part_list)
        dt_boxes_part = post_res_part[0]['points']
        dt_boxes_part = self.filter_tag_det_res(dt_boxes_part, img_ori.shape)
        dt_boxes_part = self.sorted_boxes(dt_boxes_part)
        return dt_boxes_part

    # ---- crop extraction ----------------------------------------------------
    def get_rotate_crop_image(self, img, points):
        """Perspective-warp the quadrilateral ``points`` out of ``img``.

        The output size is the longer of each pair of opposite edges. A crop
        whose height is at least 1.5x its width is rotated with ``np.rot90``.
        """
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
        pts_std = np.float32([[0, 0], [img_crop_width, 0],
                              [img_crop_width, img_crop_height],
                              [0, img_crop_height]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M, (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img

    # ---- recognition --------------------------------------------------------
    def get_img_res(self, onnx_model, img, process_op):
        """Recognize one crop.

        Args:
            onnx_model: Session object exposing ``inference(inputs=[...])``.
            img: Cropped image, HWC.
            process_op: Callable that decodes the first model output.

        Returns the value of ``process_op(outs[0])``.
        Side effect: writes the model input to ``rec_time_in0.npy``.
        """
        h, w = img.shape[:2]
        img = self.resize_norm_img(img, w * 1.0 / h)
        img = img[np.newaxis, :]
        print(img.shape)
        # The preprocessing produced NCHW; the model is fed NHWC.
        img = np.transpose(img, (0, 2, 3, 1))
        np.save('rec_time_in0.npy', img)
        outs = onnx_model.inference(inputs=[img])
        # NOTE: the session is not released here; only the message is printed.
        print('release rec model.')
        result = process_op(outs[0])
        return result

    def recognition_img(self, dt_boxes):
        """Crop every box from the stored image and recognize it.

        Returns:
            ``(results, results_info)``: ``results_info`` holds the decoder
            output for each box and ``results`` the first entry of each.
        """
        img = self.img

        # Extract one crop per bounding box.
        img_list = []
        for box in dt_boxes:
            tmp_box = copy.deepcopy(box)
            img_list.append(self.get_rotate_crop_image(img, tmp_box))

        # Recognize each crop.
        results = []
        results_info = []
        for pic in img_list:
            res = self.get_img_res(self.onet_rec_session, pic, self.postprocess_op)
            results.append(res[0])
            results_info.append(res)
        return results, results_info
if __name__ == '__main__':
    # Read the image. cv2.imread reports an unreadable file by returning
    # None rather than raising, so check before using the result.
    image = cv2.imread('test_imgs/1.jpg')
    if image is None:
        sys.exit('failed to read image: test_imgs/1.jpg')

    # OCR: detection followed by recognition.
    image = cv2.resize(image, (512, 64))
    ocr_sys = det_rec_functions(image)
    try:
        # Detect the text boxes.
        dt_boxes = ocr_sys.get_boxes()
        # results: recognition output only; results_info: output + confidence.
        results, results_info = ocr_sys.recognition_img(dt_boxes)
        print(results_info)
    finally:
        # Release both RKNN sessions even if inference raised.
        ocr_sys.onet_det_session.release()
        ocr_sys.onet_rec_session.release()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化