convert-the-model-to-rknn/008测试PadleCOR_ONNX固定输入形状模型全流程.py

398 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import yaml
import numpy as np
import onnxruntime as ort
from PIL import Image, ImageDraw, ImageFont
import math
class PaddleOCRONNX:
    """Two-stage OCR pipeline (text detection, then recognition) on ONNX Runtime.

    Both models are fed fixed-size inputs: the detector receives a
    letterboxed ``[1, 3, 640, 640]`` tensor and the recogniser a left-aligned
    ``[1, 3, 48, 320]`` tensor (defaults of the two ``resize_norm_img_*``
    methods).
    """

    # Per-channel mean / std applied to both model inputs after scaling to [0, 1].
    # The values are applied to the channels in the order the image is given
    # (cv2.imread yields BGR); no channel swap is performed.
    _MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    _STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __init__(self, det_model_path, rec_model_path, dict_path='./dict.yaml'):
        """Create the two inference sessions and load the character dictionary.

        Args:
            det_model_path: Path to the detection model (det.onnx).
            rec_model_path: Path to the recognition model (rec.onnx).
            dict_path: Path to the YAML file holding the ``character_dict``
                list. Defaults to ``./dict.yaml`` (relative to the current
                working directory).

        Raises:
            ValueError: If the dictionary is missing from the YAML file or empty.
        """
        # Detection model.
        self.det_session = ort.InferenceSession(det_model_path)
        self.det_input_name = self.det_session.get_inputs()[0].name
        # Recognition model.
        self.rec_session = ort.InferenceSession(rec_model_path)
        self.rec_input_name = self.rec_session.get_inputs()[0].name
        # Character set. decode_rec_result treats index 0 as the CTC blank, so
        # the dictionary is expected to carry a placeholder at position 0.
        self.character = self.get_dict(dict_path)
        # `not` (rather than `is None`) also rejects an empty list, which is
        # what get_dict returns when the key is absent.
        if not self.character:
            raise ValueError('请检查字典文件是否存在!')

    def get_dict(self, dict_path='./dict.yaml'):
        """Load the recognition character list from a YAML file.

        Args:
            dict_path: Path of the YAML file.

        Returns:
            The value stored under ``character_dict``, or an empty list when
            the key is absent or the document is not a mapping (e.g. an empty
            file).
        """
        with open(dict_path, 'r', encoding='utf-8') as f:
            dict_rec = yaml.safe_load(f)
        # safe_load returns None for an empty document; avoid AttributeError.
        if not isinstance(dict_rec, dict):
            return []
        return dict_rec.get('character_dict', [])

    @classmethod
    def _normalize_to_nchw(cls, hwc_image):
        """Scale an HxWx3 float image to [0, 1], standardise it, return 1x3xHxW float32."""
        normalized = (hwc_image / 255.0 - cls._MEAN) / cls._STD
        chw = normalized.transpose(2, 0, 1).astype(np.float32)
        return np.ascontiguousarray(np.expand_dims(chw, axis=0))

    def resize_norm_img_det(self, img, input_shape=(640, 640)):
        """Letterbox an image for the detector (fixed input ``[1, 3, 640, 640]``).

        The image is resized with its aspect ratio preserved, centred on a
        canvas filled with the grey value 114, and normalised.

        Args:
            img: HxWx3 image array.
            input_shape: ``(height, width)`` of the model input.

        Returns:
            ``(tensor, ratio, (top, left))`` where ``tensor`` is the 1x3xHxW
            float32 input, ``ratio`` the resize factor and ``(top, left)`` the
            offset of the resized image inside the canvas.
        """
        h, w, _ = img.shape
        target_h, target_w = input_shape
        # Uniform scale factor that fits the image inside the target.
        ratio = min(target_h / h, target_w / w)
        # max(1, ...) keeps cv2.resize valid for extremely elongated images,
        # where truncation would otherwise give a zero-sized side.
        new_h = max(1, int(h * ratio))
        new_w = max(1, int(w * ratio))
        resized_img = cv2.resize(img, (new_w, new_h))
        # Grey canvas, allocated directly as float32.
        padded_img = np.full((target_h, target_w, 3), 114.0, dtype=np.float32)
        # Centre the resized image.
        top = (target_h - new_h) // 2
        left = (target_w - new_w) // 2
        padded_img[top:top + new_h, left:left + new_w] = resized_img.astype(np.float32)
        return self._normalize_to_nchw(padded_img), ratio, (top, left)

    def post_process_det(self, dt_boxes, ratio, padding_info, ori_shape):
        """Map boxes from the letterboxed model space back to the original image.

        Args:
            dt_boxes: Array of shape (N, 4, 2) holding (x, y) corners, or None.
            ratio: Resize factor returned by ``resize_norm_img_det``.
            padding_info: ``(top, left)`` offset returned by ``resize_norm_img_det``.
            ori_shape: ``(height, width)`` of the original image.

        Returns:
            A new array with the same shape and dtype as ``dt_boxes`` whose
            coordinates are clipped to ``[0, width]`` / ``[0, height]``. The
            input is not modified. ``None`` or an empty input is returned
            unchanged.
        """
        if dt_boxes is None or len(dt_boxes) == 0:
            return dt_boxes
        ori_h, ori_w = ori_shape
        top, left = padding_info
        restored = dt_boxes.copy()
        # Undo the padding offset and the resize. Clipping happens in floating
        # point *before* the cast back to the box dtype, so an out-of-range
        # value cannot wrap around in a narrow integer type.
        restored[:, :, 0] = np.clip((dt_boxes[:, :, 0] - left) / ratio, 0, ori_w)
        restored[:, :, 1] = np.clip((dt_boxes[:, :, 1] - top) / ratio, 0, ori_h)
        return restored

    def boxes_from_bitmap(self, pred, bitmap, dest_width, dest_height,
                          max_candidates=1000, box_thresh=0.6,
                          min_size=5, unclip_ratio=1.5):
        """Extract text boxes from a binarised detection map.

        Args:
            pred: 2-D score map used to score each candidate box.
            bitmap: 2-D binary map (non-zero = text) of the same size.
            dest_width: Width of the coordinate space the boxes are scaled to.
            dest_height: Height of the coordinate space the boxes are scaled to.
            max_candidates: Maximum number of contours examined.
            box_thresh: Minimum mean score for a box to be kept.
            min_size: Minimum short-side length of a candidate rectangle; the
                expanded rectangle must reach ``min_size + 2``.
            unclip_ratio: Expansion ratio passed to ``unclip``.

        Returns:
            ``(boxes, scores)``: an int32 array of shape (N, 4, 2) (empty array
            when nothing is found) and the list of matching scores.
        """
        bitmap = bitmap.astype(np.uint8)
        height, width = bitmap.shape
        # findContours returns (image, contours, hierarchy) on OpenCV 3 and
        # (contours, hierarchy) on OpenCV 4; [-2] selects the contours in both.
        contours = cv2.findContours(bitmap, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
        boxes = []
        scores = []
        for contour in contours[:max_candidates]:
            points, sside = self.get_mini_boxes(contour)
            if sside < min_size:
                continue
            points = np.array(points)
            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if box_thresh > score:
                continue
            # Grow the box, then re-fit a rectangle around the grown polygon.
            box = self.unclip(points, unclip_ratio).reshape(-1, 1, 2)
            box, sside = self.get_mini_boxes(box)
            if sside < min_size + 2:
                continue
            box = np.array(box)
            box[:, 0] = np.clip(box[:, 0] / width * dest_width, 0, dest_width)
            box[:, 1] = np.clip(box[:, 1] / height * dest_height, 0, dest_height)
            # int32 (not int16): the boxes are later rescaled to original-image
            # coordinates, which may exceed the int16 range on very large images.
            boxes.append(box.astype(np.int32))
            scores.append(score)
        return np.array(boxes), scores

    def get_mini_boxes(self, contour):
        """Fit the minimum-area rectangle around a contour.

        Returns:
            ``(box, short_side)`` where ``box`` lists the four corners in the
            order top-left, top-right, bottom-right, bottom-left (image
            coordinates, y pointing down) and ``short_side`` is the smaller of
            the rectangle's two side lengths.
        """
        bounding_box = cv2.minAreaRect(contour)
        # Sort by x: the first two points are the left pair, the last two the right pair.
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
        # Within each pair the point with the smaller y is the upper corner.
        if points[1][1] > points[0][1]:
            index_1, index_4 = 0, 1
        else:
            index_1, index_4 = 1, 0
        if points[3][1] > points[2][1]:
            index_2, index_3 = 2, 3
        else:
            index_2, index_3 = 3, 2
        box = [points[index_1], points[index_2], points[index_3], points[index_4]]
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        """Return the mean of ``bitmap`` over the polygon ``_box``.

        Only the polygon's axis-aligned bounding rectangle (clamped to the
        map) is rasterised, which keeps the mask small.

        Args:
            bitmap: 2-D score map.
            _box: Array of shape (K, 2) with the polygon's (x, y) vertices.
        """
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = np.clip(np.floor(box[:, 0].min()).astype(int), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype(int), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype(int), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype(int), 0, h - 1)
        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        # Shift the polygon into the local frame of the mask.
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def unclip(self, box, unclip_ratio):
        """Expand a text polygon outwards.

        The offset distance is ``area * unclip_ratio / perimeter``.

        Args:
            box: Array of polygon vertices.
            unclip_ratio: Expansion ratio.

        Returns:
            The expanded polygon as an array, or ``box`` unchanged when the
            polygon is degenerate or the offset yields no result.
        """
        # Imported lazily so the optional geometry packages are only required
        # when detection post-processing actually runs.
        from shapely.geometry import Polygon
        import pyclipper
        poly = Polygon(box)
        if poly.length == 0:
            # Zero perimeter: nothing to expand, and the division would fail.
            return box
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = offset.Execute(distance)
        if len(expanded) == 0:
            return box
        return np.array(expanded[0])

    def resize_norm_img_rec(self, img, input_shape=(320, 48)):
        """Prepare a text crop for the recogniser (fixed input ``[1, 3, 48, 320]``).

        The crop is resized with its aspect ratio preserved, placed in the
        top-left corner of a zero-filled canvas, and normalised.

        NOTE(review): the canvas is zero-filled *before* normalisation and the
        same mean/std as the detector is used; confirm this matches how the
        recognition model was trained.

        Args:
            img: HxWx3 image array.
            input_shape: ``(width, height)`` of the model input. Note that the
                width comes first here.

        Returns:
            A 1x3xHxW float32 tensor.
        """
        target_w, target_h = input_shape  # width first
        h, w = img.shape[:2]
        ratio = min(target_h / h, target_w / w)
        # max(1, ...) keeps cv2.resize valid for very tall or very wide crops,
        # where truncation would otherwise give a zero-sized side.
        new_h = max(1, int(h * ratio))
        new_w = max(1, int(w * ratio))
        resized_image = cv2.resize(img, (new_w, new_h))
        # Black canvas, allocated directly as float32; content is left-aligned.
        padded_image = np.zeros((target_h, target_w, 3), dtype=np.float32)
        padded_image[:new_h, :new_w] = resized_image.astype(np.float32)
        return self._normalize_to_nchw(padded_image)

    def decode_rec_result(self, preds_prob):
        """Greedy CTC decoding of a recognition output.

        Args:
            preds_prob: 2-D array of shape (time_steps, num_classes).

        Returns:
            ``(text, conf)``: the decoded string and, as a Python float, the
            mean probability of the emitted characters (0.0 when nothing is
            emitted).
        """
        preds_idx = np.argmax(preds_prob, axis=1)
        preds_max = np.max(preds_prob, axis=1)
        last_idx = 0
        preds_text = []
        preds_conf = []
        for idx, prob in zip(preds_idx, preds_max):
            # Emit on a change of class, skipping the blank (index 0) and any
            # index that falls outside the dictionary.
            if idx != last_idx and idx != 0 and idx < len(self.character):
                # str(): YAML may load unquoted entries (e.g. digits) as non-str.
                preds_text.append(str(self.character[idx]))
                preds_conf.append(prob)
            last_idx = idx
        text = ''.join(preds_text)
        conf = float(np.mean(preds_conf)) if preds_conf else 0.0
        return text, conf

    def detect_text(self, image):
        """Run text detection on an image.

        Args:
            image: HxWx3 image array.

        Returns:
            ``(boxes, scores)``: boxes of shape (N, 4, 2) in original-image
            coordinates (an empty array when nothing is found) and the list of
            per-box scores.
        """
        ori_h, ori_w = image.shape[:2]
        # Pre-processing.
        det_img, ratio, padding_info = self.resize_norm_img_det(image)
        # Inference.
        det_output = self.det_session.run(None, {self.det_input_name: det_img})[0]
        # Post-processing: binarise the first channel of the first output.
        mask = det_output[0, 0, :, :]
        threshold = 0.3
        bitmap = (mask > threshold).astype(np.uint8) * 255
        # Boxes are expressed in the model-input coordinate space, whose size
        # is read from the prepared tensor instead of being hard-coded.
        input_h, input_w = det_img.shape[2:4]
        boxes, scores = self.boxes_from_bitmap(mask, bitmap, input_w, input_h)
        # Convert to original-image coordinates.
        if len(boxes) > 0:
            boxes = self.post_process_det(boxes, ratio, padding_info, (ori_h, ori_w))
        return boxes, scores

    def recognize_text(self, image):
        """Recognise the text in a single cropped text-line image.

        Returns:
            ``(text, conf)`` as produced by ``decode_rec_result``.
        """
        rec_img = self.resize_norm_img_rec(image)
        rec_output = self.rec_session.run(None, {self.rec_input_name: rec_img})[0]
        return self.decode_rec_result(rec_output[0])

    def get_rotate_crop_image(self, img, points):
        """Crop and rectify the quadrilateral ``points`` out of ``img``.

        Args:
            img: Source image.
            points: float32 array of shape (4, 2) ordered top-left, top-right,
                bottom-right, bottom-left.

        Returns:
            The perspective-corrected crop, rotated by 90 degrees when its
            height is at least 1.5 times its width; ``None`` when the
            quadrilateral has a zero-length side and cannot be cropped.
        """
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
        # Boxes clipped to the image border can collapse to a line or a point;
        # warping to a zero-sized target would raise inside OpenCV.
        if img_crop_width < 1 or img_crop_height < 1:
            return None
        pts_std = np.float32([[0, 0], [img_crop_width, 0],
                              [img_crop_width, img_crop_height],
                              [0, img_crop_height]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M, (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img

    def ocr(self, image_path):
        """Run the full pipeline (detection + recognition) on an image file.

        Args:
            image_path: Path of the image to read.

        Returns:
            A list of dicts with keys ``text``, ``confidence``, ``box`` and
            ``score``, one per recognised line whose confidence exceeds 0.5.
            An empty list is returned when the image cannot be read or no text
            is detected.
        """
        image = cv2.imread(image_path)
        if image is None:
            return []
        # 1. Detection.
        dt_boxes, scores = self.detect_text(image)
        if dt_boxes is None or len(dt_boxes) == 0:
            return []
        # 2. Recognition.
        ocr_results = []
        for box, score in zip(dt_boxes, scores):
            crop_img = self.get_rotate_crop_image(image, box.astype(np.float32))
            if crop_img is None:
                # Degenerate box: nothing to recognise.
                continue
            text, conf = self.recognize_text(crop_img)
            if conf > 0.5:  # confidence filter
                ocr_results.append({
                    'text': text,
                    'confidence': conf,
                    'box': box.tolist(),
                    'score': score
                })
        return ocr_results
# Usage example
def main():
    """Run OCR on one sample image, print every result and save a visualisation."""
    # An earlier model pair (det_shape.onnx / rec_shape.onnx) lived in the
    # same directory; the dated exports below are the ones in use.
    engine = PaddleOCRONNX(
        '/home/admin-root/haotian/康达瑞贝斯机器狗/det_shape_20250814.onnx',
        '/home/admin-root/haotian/康达瑞贝斯机器狗/rec_shape_20250815.onnx',
    )

    # Run the pipeline on the sample frame.
    image_path = '/home/admin-root/haotian/康达瑞贝斯机器狗/data_image/001读表图片/3aee64cc1f90d93a5a45979f7b17cb4b_frame_001460.jpg'
    recognized = engine.ocr(image_path)

    # Report each recognised line followed by a separator.
    for item in recognized:
        print(
            f"文本: {item['text']}",
            f"置信度: {item['confidence']:.3f}",
            f"检测得分: {item['score']:.3f}",
            f"坐标: {item['box']}",
            "-" * 50,
            sep="\n",
        )

    # Draw the results onto the image and write it to disk.
    visualize_results(image_path, recognized)
def visualize_results(image_path, results, output_path='result_shape_20250815.jpg'):
    """Draw OCR results onto the source image and save the annotated copy.

    Args:
        image_path: Path of the image the results were computed on.
        results: Iterable of dicts with at least the keys ``box`` (polygon as
            a list of (x, y) points) and ``text``.
        output_path: Where the annotated image is written. Defaults to
            ``result_shape_20250815.jpg`` in the current working directory.

    Returns:
        The status returned by ``cv2.imwrite`` (falsy when the image could not
        be written).

    Raises:
        ValueError: If the image at ``image_path`` cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f'cannot read image: {image_path}')
    for result in results:
        box = np.array(result['box'], dtype=np.int32)
        cv2.polylines(image, [box], True, (0, 255, 0), 2)
        # Label just above the first corner. Plain Python ints are passed
        # because some OpenCV builds reject NumPy integer scalars as the
        # text origin.
        # NOTE: the Hershey fonts cover ASCII only; characters they cannot
        # render (e.g. Chinese text) are drawn as question marks.
        origin = (int(box[0][0]), int(box[0][1]) - 10)
        cv2.putText(image, result['text'], origin,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
    return cv2.imwrite(output_path, image)
# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()