convert-the-model-to-rknn/017能用的PaddleOCR_rknn脚本.py

559 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
RK3588 PaddleOCR RKNN推理程序
使用转换后的RKNN模型在RK3588上进行OCR文本检测和识别
"""
import cv2
import yaml
import numpy as np
import math
from rknn.api import RKNN
import argparse
import os
import time
class RK3588OCR:
    """Two-stage OCR (text detection, then text recognition) on RKNN models.

    The detector output is binarised and turned into quadrilateral boxes;
    each box is perspective-cropped from the source image and passed to the
    recognizer, whose per-timestep class scores are greedily CTC-decoded
    against the character list loaded by :meth:`get_dict`.
    """

    # Per-channel mean / std applied to the detector input after scaling
    # the pixel values to [0, 1].
    _DET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    _DET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __init__(self, det_model_path, rec_model_path):
        """Create both RKNN runtimes and load the character dictionary.

        Args:
            det_model_path: Path of the text-detection RKNN model.
            rec_model_path: Path of the text-recognition RKNN model.

        Raises:
            RuntimeError: If a model cannot be loaded or its runtime cannot
                be initialised (see :meth:`_load_models`).
        """
        self.det_model_path = det_model_path
        self.rec_model_path = rec_model_path
        # One RKNN instance per model.
        self.det_rknn = RKNN(verbose=False)
        self.rec_rknn = RKNN(verbose=False)
        # Fixed model input sizes: detector is (height, width),
        # recognizer is (width, height) - see the resize_norm_img_* methods.
        self.det_input_size = (640, 640)
        self.rec_input_size = (320, 48)
        # Detection post-processing parameters.
        self.det_threshold = 0.3       # binarisation threshold of the score map
        self.det_box_threshold = 0.6   # minimum mean score inside a box
        self.det_unclip_ratio = 1.5    # how far a detected box is expanded
        try:
            self._load_models()
            self.character = self.get_dict()
        except Exception:
            # Do not leak the RKNN contexts when start-up fails half-way.
            self.release()
            raise

    def get_dict(self, dict_path='/home/orangepi/Desktop/kangda_robotic_dog/机器狗后台服务/dict.yaml'):
        """Load the recognition character list from a YAML file.

        Args:
            dict_path: YAML file holding a ``character_dict`` entry.

        Returns:
            The value stored under ``character_dict``, or ``[]`` when the key
            is missing or the file is empty.
        """
        with open(dict_path, 'r', encoding='utf-8') as f:
            dict_rec = yaml.safe_load(f)
        # safe_load returns None for an empty document.
        return (dict_rec or {}).get('character_dict', [])

    @staticmethod
    def _load_and_init(rknn, model_path, label):
        """Load one RKNN model and initialise its runtime for the rk3588 target."""
        ret = rknn.load_rknn(model_path)
        if ret != 0:
            raise RuntimeError(f"加载{label}模型失败: {ret}")
        ret = rknn.init_runtime(target='rk3588')
        if ret != 0:
            raise RuntimeError(f"初始化{label}模型运行环境失败: {ret}")

    def _load_models(self):
        """Load the detection and the recognition model.

        Raises:
            RuntimeError: If loading or runtime initialisation returns a
                non-zero status.
        """
        print("加载文本检测模型...")
        self._load_and_init(self.det_rknn, self.det_model_path, "检测")
        print("加载文本识别模型...")
        self._load_and_init(self.rec_rknn, self.rec_model_path, "识别")
        print("模型加载完成!")

    def resize_norm_img_det(self, img, input_shape=(640, 640)):
        """Letterbox and normalise an image for the detector.

        The image is resized with its aspect ratio preserved, centred on a
        grey (114) canvas of ``input_shape``, scaled to [0, 1], normalised
        with the class mean/std and converted to NCHW float32.

        Args:
            img: 3-channel image of shape (H, W, 3).
            input_shape: Target (height, width).

        Returns:
            Tuple ``(tensor, ratio, (top, left))`` where ``tensor`` has shape
            (1, 3, height, width), ``ratio`` is the applied scale factor and
            ``(top, left)`` is the offset of the resized image on the canvas.
        """
        h, w, _ = img.shape
        target_h, target_w = input_shape
        # Single scale factor so that the aspect ratio is kept.
        ratio = min(target_h / h, target_w / w)
        # Clamp to 1 px: extreme aspect ratios would otherwise round to 0,
        # which cv2.resize rejects.
        new_h = max(1, int(h * ratio))
        new_w = max(1, int(w * ratio))
        resized_img = cv2.resize(img, (new_w, new_h))
        # Grey canvas, allocated directly as float32.
        padded_img = np.full((target_h, target_w, 3), 114.0, dtype=np.float32)
        top = (target_h - new_h) // 2
        left = (target_w - new_w) // 2
        padded_img[top:top + new_h, left:left + new_w] = resized_img.astype(np.float32)
        normalized = (padded_img / 255.0 - self._DET_MEAN) / self._DET_STD
        chw = normalized.transpose(2, 0, 1).astype(np.float32)
        return np.expand_dims(chw, axis=0).astype(np.float32), ratio, (top, left)

    def post_process_det(self, dt_boxes, ratio, padding_info, ori_shape):
        """Map boxes from letterboxed model space back to the original image.

        The computation is done on a floating-point copy, so the caller's
        array is left untouched and narrow integer inputs (the boxes used to
        be int16) cannot wrap around before they are clipped.

        Args:
            dt_boxes: Array of shape (N, 4, 2) in model-input coordinates,
                or ``None``.
            ratio: Scale factor returned by :meth:`resize_norm_img_det`.
            padding_info: ``(top, left)`` letterbox offset.
            ori_shape: ``(height, width)`` of the original image.

        Returns:
            ``None`` if ``dt_boxes`` is ``None``; the input unchanged if it is
            empty; otherwise a new array clipped to the original image.
            Integer inputs yield int32 (fraction truncated), floating inputs
            keep their dtype.
        """
        if dt_boxes is None:
            return None
        source = np.asarray(dt_boxes)
        if source.size == 0:
            return dt_boxes
        ori_h, ori_w = ori_shape
        top, left = padding_info
        boxes = source.astype(np.float64)  # astype always copies here
        # Undo the letterbox offset and scale, then clip to the image.
        boxes[:, :, 0] = np.clip((boxes[:, :, 0] - left) / ratio, 0, ori_w)
        boxes[:, :, 1] = np.clip((boxes[:, :, 1] - top) / ratio, 0, ori_h)
        if np.issubdtype(source.dtype, np.integer):
            return boxes.astype(np.int32)
        return boxes.astype(source.dtype)

    def boxes_from_bitmap(self, pred, bitmap, dest_width, dest_height,
                          max_candidates=1000, box_thresh=0.6,
                          unclip_ratio=1.5, min_size=5):
        """Extract text boxes from a binarised score map.

        Args:
            pred: Score map used to rate each candidate box.
            bitmap: Binary map whose contours are the candidates.
            dest_width: Width of the coordinate space of the returned boxes.
            dest_height: Height of the coordinate space of the returned boxes.
            max_candidates: Maximum number of contours examined.
            box_thresh: Candidates whose mean score is below this are dropped.
            unclip_ratio: Expansion ratio passed to :meth:`unclip`.
            min_size: Minimum short side of a candidate; the expanded box
                must reach ``min_size + 2``.

        Returns:
            Tuple ``(boxes, scores)``: an int32 array of shape (N, 4, 2)
            (an empty array when nothing is kept) and the list of scores.
        """
        bitmap = bitmap.astype(np.uint8)
        height, width = bitmap.shape
        found = cv2.findContours(bitmap, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
        # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns
        # (image, contours, hierarchy).
        contours = found[0] if len(found) == 2 else found[1]
        boxes = []
        scores = []
        for contour in contours[:max_candidates]:
            points, short_side = self.get_mini_boxes(contour)
            if short_side < min_size:
                continue
            points = np.array(points)
            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if box_thresh > score:
                continue
            # Expand the box, then re-fit a rectangle around the result.
            expanded = self.unclip(points, unclip_ratio).reshape(-1, 1, 2)
            box, short_side = self.get_mini_boxes(expanded)
            if short_side < min_size + 2:
                continue
            box = np.array(box)
            box[:, 0] = np.clip(box[:, 0] / width * dest_width, 0, dest_width)
            box[:, 1] = np.clip(box[:, 1] / height * dest_height, 0, dest_height)
            boxes.append(box.astype(np.int32))
            scores.append(score)
        return np.array(boxes), scores

    def get_mini_boxes(self, contour):
        """Fit the minimum-area rectangle around a contour.

        Returns:
            Tuple ``(box, short_side)``: the four corners ordered top-left,
            top-right, bottom-right, bottom-left (image coordinates, y down)
            and the length of the rectangle's shorter side.
        """
        rect = cv2.minAreaRect(contour)
        # Sort by x: the first two corners are the left pair, the last two
        # the right pair.
        corners = sorted(cv2.boxPoints(rect), key=lambda p: p[0])
        if corners[1][1] > corners[0][1]:
            top_left, bottom_left = corners[0], corners[1]
        else:
            top_left, bottom_left = corners[1], corners[0]
        if corners[3][1] > corners[2][1]:
            top_right, bottom_right = corners[2], corners[3]
        else:
            top_right, bottom_right = corners[3], corners[2]
        return [top_left, top_right, bottom_right, bottom_left], min(rect[1])

    def box_score_fast(self, bitmap, _box):
        """Return the mean of ``bitmap`` inside the polygon ``_box``.

        Only the axis-aligned bounding rectangle of the polygon is examined,
        with the polygon rasterised into a mask of that size.
        """
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = np.clip(np.floor(box[:, 0].min()).astype(int), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype(int), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype(int), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype(int), 0, h - 1)
        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        # Shift the polygon into the coordinate frame of the mask.
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def unclip(self, box, unclip_ratio):
        """Expand a polygon outwards by ``area * unclip_ratio / perimeter``.

        Returns:
            The first expanded polygon as an array, or ``box`` unchanged when
            the offset operation yields nothing.
        """
        # Imported here because only this method needs them.
        from shapely.geometry import Polygon
        import pyclipper
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = offset.Execute(distance)
        if len(expanded) == 0:
            return box
        return np.array(expanded[0])

    def resize_norm_img_rec(self, img, input_shape=(320, 48)):
        """Resize, pad and scale a text crop for the recognizer.

        The crop is resized with its aspect ratio preserved, placed in the
        top-left corner of a black canvas and scaled to [0, 1].

        Args:
            img: 3-channel crop of shape (H, W, 3).
            input_shape: Target ``(width, height)`` - note: width first.

        Returns:
            float32 tensor of shape (1, 3, height, width).
        """
        target_w, target_h = input_shape  # width comes first here
        h, w = img.shape[:2]
        ratio = min(target_h / h, target_w / w)
        # Clamp to 1 px: very elongated crops would otherwise round to 0,
        # which cv2.resize rejects.
        new_h = max(1, int(h * ratio))
        new_w = max(1, int(w * ratio))
        resized_image = cv2.resize(img, (new_w, new_h))
        # Black canvas, allocated directly as float32.
        padded_image = np.zeros((target_h, target_w, 3), dtype=np.float32)
        # Left-aligned, as recognition models usually expect.
        padded_image[:new_h, :new_w] = resized_image.astype(np.float32)
        # Only the /255 scaling is applied. A mean/std normalisation like the
        # detector's was tried and left disabled by the original author.
        # NOTE(review): the author observed recognition results shifting by
        # one position depending on the scaling - cause unconfirmed.
        padded_image = (padded_image / 255.0).astype(np.float32)
        padded_image = padded_image.transpose((2, 0, 1)).astype(np.float32)
        return np.expand_dims(padded_image, axis=0).astype(np.float32)

    def decode_rec_result(self, preds_prob):
        """Greedy CTC decoding of the recognizer output.

        Per timestep the best class is taken; repeated classes are collapsed,
        class 0 is treated as the CTC blank, and indices outside
        ``self.character`` are ignored.

        Args:
            preds_prob: Scores of shape (batch, timesteps, classes); only the
                first batch entry is decoded.

        Returns:
            Tuple ``(text, confidence)`` where ``confidence`` is the mean
            score of the emitted characters, or ``0.0`` if none was emitted.
        """
        probs = np.asarray(preds_prob)[0]
        indices = np.argmax(probs, axis=1)
        best = np.max(probs, axis=1)
        # Class at the previous timestep (blank before the first one).
        previous = np.concatenate(([0], indices[:-1]))
        keep = (indices != previous) & (indices != 0) & (indices < len(self.character))
        text = ''.join(str(self.character[i]) for i in indices[keep])
        conf = best[keep].mean() if keep.any() else 0.0
        return text, conf

    def detect_text(self, image):
        """Run the detector on an image.

        Args:
            image: 3-channel image of shape (H, W, 3).

        Returns:
            Tuple ``(boxes, scores)`` with boxes in original-image coordinates
            (an empty array when nothing was found).

        Raises:
            RuntimeError: If the runtime returns no output.
        """
        ori_h, ori_w = image.shape[:2]
        det_h, det_w = self.det_input_size
        det_img, ratio, padding_info = self.resize_norm_img_det(image, self.det_input_size)
        outputs = self.det_rknn.inference(inputs=[det_img], data_format="nchw")
        if outputs is None:
            raise RuntimeError("检测模型推理失败")
        mask = outputs[0][0, 0, :, :]
        bitmap = (mask > self.det_threshold).astype(np.uint8) * 255
        # Boxes come back in model-input coordinates.
        boxes, scores = self.boxes_from_bitmap(
            mask, bitmap, det_w, det_h,
            box_thresh=self.det_box_threshold,
            unclip_ratio=self.det_unclip_ratio)
        if len(boxes) > 0:
            boxes = self.post_process_det(boxes, ratio, padding_info, (ori_h, ori_w))
        print("*"*100, len(boxes))
        return boxes, scores

    def visualize_det_results(self, image_path, boxes, output_path='./visual_det.jpg'):
        """Draw the detected boxes and write the result to ``output_path``.

        Args:
            image_path: Path of an image file, or an already loaded image
                array (which is copied, never drawn on).
            boxes: Iterable of 4-point boxes, or ``None``.
            output_path: Where the annotated image is written.

        Nothing is written when the image cannot be read.
        """
        if isinstance(image_path, np.ndarray):
            image = image_path.copy()
        else:
            image = cv2.imread(image_path)
        if image is None:
            return
        if boxes is not None:
            for box in boxes:
                cv2.polylines(image, [np.array(box, dtype=np.int32)], True, (0, 255, 0), 2)
        cv2.imwrite(output_path, image)

    def recognize_text(self, image):
        """Recognise the text of a single crop.

        Returns:
            Tuple ``(text, confidence)`` as produced by
            :meth:`decode_rec_result`.

        Raises:
            RuntimeError: If the runtime returns no output.
        """
        rec_img = self.resize_norm_img_rec(image)
        rec_output = self.rec_rknn.inference(inputs=[rec_img], data_format="nchw")
        if rec_output is None:
            raise RuntimeError("识别模型推理失败")
        return self.decode_rec_result(rec_output[0])

    @staticmethod
    def _crop_size(points):
        """Return the integer (width, height) of the crop spanned by 4 corners."""
        width = int(max(np.linalg.norm(points[0] - points[1]),
                        np.linalg.norm(points[2] - points[3])))
        height = int(max(np.linalg.norm(points[0] - points[3]),
                         np.linalg.norm(points[1] - points[2])))
        return width, height

    def get_rotate_crop_image(self, img, points):
        """Crop a quadrilateral from ``img`` and rectify it.

        Args:
            img: Source image.
            points: float32 array of 4 corners ordered top-left, top-right,
                bottom-right, bottom-left.

        Returns:
            The rectified crop; crops at least 1.5 times taller than wide
            are rotated by 90 degrees.
        """
        img_crop_width, img_crop_height = self._crop_size(points)
        pts_std = np.float32([[0, 0], [img_crop_width, 0],
                              [img_crop_width, img_crop_height],
                              [0, img_crop_height]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M, (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img

    def ocr(self, image_path, min_confidence=0.4, visualize=True):
        """Run detection and recognition on one image.

        Args:
            image_path: Path of an image file, or an already loaded image
                array (e.g. a video frame).
            min_confidence: Recognitions at or below this confidence are
                dropped.
            visualize: When true, the detected boxes are written to
                ``./visual_det.jpg``.

        Returns:
            ``[]`` when the image cannot be read or no text box is found;
            otherwise ``[text_list, confidence_list]`` with confidences
            rounded to two decimals. Callers must handle both shapes.
        """
        if isinstance(image_path, np.ndarray):
            image = image_path
        else:
            image = cv2.imread(image_path)
        if image is None:
            return []
        # 1. Text detection.
        dt_boxes, _scores = self.detect_text(image)
        if visualize:
            self.visualize_det_results(image, dt_boxes)
        if dt_boxes is None or len(dt_boxes) == 0:
            return []
        # 2. Text recognition, box by box.
        text_list = []
        confidence_list = []
        for box in dt_boxes:
            box_points = box.astype(np.float32)
            # A box collapsed to less than one pixel cannot be cropped.
            if min(self._crop_size(box_points)) < 1:
                continue
            crop_img = self.get_rotate_crop_image(image, box_points)
            text, conf = self.recognize_text(crop_img)
            if conf > min_confidence:
                text_list.append(text)
                confidence_list.append(round(float(conf), 2))
        return [text_list, confidence_list]

    def release(self):
        """Release both RKNN runtimes (the second one even if the first fails)."""
        try:
            self.det_rknn.release()
        finally:
            self.rec_rknn.release()
def _ocr_frame(ocr, frame, min_confidence=0.4):
    """Detect and recognise text on an in-memory frame.

    Uses only the detector / cropper / recognizer methods of the OCR object,
    so no image file is read or written per frame.

    Returns:
        List of ``(box, text, confidence)`` tuples for recognitions whose
        confidence exceeds ``min_confidence``.
    """
    boxes, _ = ocr.detect_text(frame)
    results = []
    if boxes is None or len(boxes) == 0:
        return results
    for box in boxes:
        points = np.asarray(box, dtype=np.float32)
        width = max(np.linalg.norm(points[0] - points[1]),
                    np.linalg.norm(points[2] - points[3]))
        height = max(np.linalg.norm(points[0] - points[3]),
                     np.linalg.norm(points[1] - points[2]))
        # A box collapsed to less than one pixel cannot be cropped.
        if width < 1 or height < 1:
            continue
        crop = ocr.get_rotate_crop_image(frame, points)
        text, conf = ocr.recognize_text(crop)
        if conf > min_confidence:
            results.append((box, text, float(conf)))
    return results


def _draw_boxes(frame, results):
    """Return a copy of ``frame`` with the box of every result outlined."""
    annotated = frame.copy()
    for box, _text, _conf in results:
        cv2.polylines(annotated, [np.array(box, dtype=np.int32)], True, (0, 255, 0), 2)
    return annotated


def main():
    """Command-line entry point: OCR on an image, a video file or a camera.

    ``--output`` and ``--show`` are accepted but currently have no effect.
    """
    parser = argparse.ArgumentParser(description='RK3588 PaddleOCR RKNN推理')
    parser.add_argument('--det_model', type=str, required=True, help='文本检测RKNN模型路径')
    parser.add_argument('--rec_model', type=str, required=True, help='文本识别RKNN模型路径')
    parser.add_argument('--image', type=str, help='输入图像路径')
    parser.add_argument('--video', type=str, help='输入视频路径')
    parser.add_argument('--camera', type=int, help='摄像头设备ID')
    parser.add_argument('--output', type=str, help='输出路径')
    parser.add_argument('--show', action='store_true', help='显示结果')
    args = parser.parse_args()
    # Check the model files before touching the NPU runtime.
    if not os.path.exists(args.det_model):
        print(f"检测模型文件不存在: {args.det_model}")
        return
    if not os.path.exists(args.rec_model):
        print(f"识别模型文件不存在: {args.rec_model}")
        return
    # Fail fast: no point in loading the models without an input source.
    if not args.image and not args.video and args.camera is None:
        print("请指定输入源: --image, --video 或 --camera")
        return
    print("初始化RK3588 OCR...")
    ocr = RK3588OCR(args.det_model, args.rec_model)
    try:
        if args.image:
            # Single-image mode.
            print(f"处理图像: {args.image}")
            start_time = time.time()
            result = ocr.ocr(args.image)
            total_time = time.time() - start_time
            # ocr() returns [] when the image is unreadable or holds no text.
            text, confidence = result if result else ([], [])
            print("text", text)
            print("confidence", confidence)
            print(f"\n总耗时: {total_time:.3f}s")
            print("识别结果:")
            for i, (item, conf) in enumerate(zip(text, confidence), start=1):
                print(f"{i}. 文本: '{item}', 置信度: {conf:.3f}")
        else:
            # Video-file or camera mode.
            if args.video:
                cap = cv2.VideoCapture(args.video)
                print(f"处理视频: {args.video}")
            else:
                cap = cv2.VideoCapture(args.camera)
                print(f"使用摄像头: {args.camera}")
            try:
                if not cap.isOpened():
                    print("无法打开视频源")
                    return
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    results = _ocr_frame(ocr, frame)
                    # Recognised text goes to the console; only the boxes are
                    # drawn on the preview.
                    for _box, item, conf in results:
                        print(f"文本: '{item}', 置信度: {conf:.3f}")
                    cv2.imshow('Real-time OCR', _draw_boxes(frame, results))
                    # Press 'q' to quit.
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
            finally:
                cap.release()
                cv2.destroyAllWindows()
    finally:
        ocr.release()
def print_usage():
    """Print example command lines for image, camera and video input."""
    usage_lines = (
        "RK3588 PaddleOCR RKNN推理程序",
        "\n使用示例:",
        "# 处理单张图像",
        "python rk3588_ocr.py \\",
        " --det_model ./rknn_models/text_detection.rknn \\",
        " --rec_model ./rknn_models/text_recognition.rknn \\",
        " --image ./test.jpg \\",
        " --output ./result.jpg \\",
        " --show",
        "",
        "# 实时摄像头OCR",
        "python rk3588_ocr.py \\",
        " --det_model ./rknn_models/text_detection.rknn \\",
        " --rec_model ./rknn_models/text_recognition.rknn \\",
        " --camera 0",
        "",
        "# 处理视频文件",
        "python rk3588_ocr.py \\",
        " --det_model ./rknn_models/text_detection.rknn \\",
        " --rec_model ./rknn_models/text_recognition.rknn \\",
        " --video ./input_video.mp4",
    )
    for line in usage_lines:
        print(line)


if __name__ == "__main__":
    # Launch example:
    #   python 010使用PaddleOCR_rknn.py --det_model ./text_detection.rknn --rec_model ./text_recognition.rknn --image ./image_test/632e474452d560edd7004f745319ff00_frame_000730.jpg --output ./result.jpg
    # Note: the exported RKNN models were built without normalization
    # (normalization parameters mean=0, std=1).
    import sys  # local import: the file's import block does not provide sys

    # Without any argument, show usage examples instead of argparse's error.
    if len(sys.argv) == 1:
        print_usage()
    else:
        main()