// OrangePi3588Media/include/utils/sliding_window_detector.h
// (Captured from a web source viewer: 263 lines, 8.5 KiB, C++. The viewer
//  flagged that the file contains ambiguous Unicode characters.)

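/**
 * sliding_window_detector.h - generic sliding-window detection helper
 *
 * Features:
 * 1. Uses the windows listed in the config, or computes a window grid automatically
 * 2. Crops a window from the source image and resizes it to the model input size
 * 3. Maps detection boxes from model coordinates back to the original image
 * 4. Merging overlapping results across windows (e.g. NMS) is NOT done by this
 *    helper; collect the mapped boxes from all windows and run NMS in the caller.
 *
 * Usage:
 *   SlidingWindowDetector swd;
 *   swd.InitFromConfig(config);  // read windows from the config
 *   auto windows = swd.GetWindows(src_w, src_h);
 *   std::vector<uint8_t> input(static_cast<size_t>(model_w) * model_h * 3);
 *   for (const auto& win : windows) {
 *     if (!swd.PrepareInput(frame, win, model_w, model_h, input.data())) continue;
 *     // ... inference ...
 *     auto dets = swd.MapDetectionsToOriginal(raw_dets, win, model_w, model_h);
 *   }
 */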
#pragma once
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <vector>
#include "frame/frame.h"
#include "utils/simple_json.h"
namespace rk3588 {
/// A rectangular region of the source image (x/y = top-left corner, w/h = size).
/// A default-constructed window is 640x640 anchored at the origin.
struct DetectionWindow {
    int x{0};
    int y{0};
    int w{640};
    int h{640};

    /// A window is usable only when it has a strictly positive area.
    bool IsValid() const { return !(w <= 0 || h <= 0); }
};
// Detection box (generic format): top-left corner plus width/height, a
// confidence score and a class index. Coordinates are in whichever space the
// producer uses (model input or original image).
// Every field has a zero default so a default-constructed box is never read
// uninitialized; the type remains an aggregate, so positional brace
// initialization {x, y, w, h, confidence, class_id} still works.
struct DetectionBox {
    float x = 0.0f;           // left edge
    float y = 0.0f;           // top edge
    float w = 0.0f;           // width
    float h = 0.0f;           // height
    float confidence = 0.0f;  // detection score
    int class_id = 0;         // class index
};
/**
* 滑动窗口检测器
*/
class SlidingWindowDetector {
public:
SlidingWindowDetector() = default;
/**
* 从配置初始化
* @param config SimpleJson 配置,支持 "windows" 数组
* @return 是否成功
*/
bool InitFromConfig(const SimpleJson& config) {
windows_.clear();
// 解析预配置窗口
if (const SimpleJson* win_arr = config.Find("windows"); win_arr && win_arr->IsArray()) {
for (const auto& w : win_arr->AsArray()) {
if (w.IsObject()) {
DetectionWindow win;
win.x = w.ValueOr<int>("x", 0);
win.y = w.ValueOr<int>("y", 0);
win.w = w.ValueOr<int>("w", 640);
win.h = w.ValueOr<int>("h", 640);
if (win.IsValid()) {
windows_.push_back(win);
}
}
}
}
// 目标 resize 高度(用于自动计算窗口时)
target_height_ = config.ValueOr<int>("target_height", 640);
return true;
}
/**
* 获取窗口列表(预配置或自动计算)
* @param src_w 源图宽度
* @param src_h 源图高度
* @return 窗口列表
*/
std::vector<DetectionWindow> GetWindows(int src_w, int src_h) const {
if (!windows_.empty()) {
return windows_;
}
return CalculateWindowsAuto(src_w, src_h);
}
/**
* 准备模型输入
* 从源图裁剪窗口区域resize 到模型输入尺寸
*
* @param frame 源帧
* @param win 窗口定义
* @param model_w 模型输入宽
* @param model_h 模型输入高
* @param output 输出缓冲区model_w * model_h * 3
* @return 是否成功
*/
bool PrepareInput(const FramePtr& frame,
const DetectionWindow& win,
int model_w, int model_h,
uint8_t* output) const {
if (!frame || !frame->data || !output) return false;
const int src_w = frame->width;
const int src_h = frame->height;
// 获取源数据指针
const uint8_t* src = frame->planes[0].data ? frame->planes[0].data : frame->data;
const int src_stride = frame->planes[0].stride > 0 ? frame->planes[0].stride
: (frame->stride > 0 ? frame->stride : frame->width * 3);
// 限制窗口在源图范围内
int win_x = std::max(0, std::min(win.x, src_w - 1));
int win_y = std::max(0, std::min(win.y, src_h - 1));
int win_w = std::min(win.w, src_w - win_x);
int win_h = std::min(win.h, src_h - win_y);
if (win_w <= 0 || win_h <= 0) return false;
// 裁剪窗口
std::vector<uint8_t> crop_buf(static_cast<size_t>(win_w) * win_h * 3);
for (int row = 0; row < win_h; ++row) {
const uint8_t* src_row = src + (win_y + row) * src_stride + win_x * 3;
uint8_t* dst_row = crop_buf.data() + row * win_w * 3;
memcpy(dst_row, src_row, static_cast<size_t>(win_w) * 3);
}
// Resize 到模型输入尺寸
ResizeRgbBilinear(crop_buf.data(), win_w, win_h, win_w * 3,
output, model_w, model_h, false);
return true;
}
/**
* 将检测结果从模型坐标映射回原图坐标
*
* @param detections 模型输出的检测框(在 model_w x model_h 坐标系中)
* @param win 窗口定义
* @param model_w 模型输入宽
* @param model_h 模型输入高
* @return 映射后的检测框(在原图坐标系中)
*/
std::vector<DetectionBox> MapDetectionsToOriginal(
const std::vector<DetectionBox>& detections,
const DetectionWindow& win,
int model_w, int model_h) const {
std::vector<DetectionBox> mapped = detections;
float scale_x = static_cast<float>(win.w) / static_cast<float>(model_w);
float scale_y = static_cast<float>(win.h) / static_cast<float>(model_h);
for (auto& det : mapped) {
det.x = win.x + det.x * scale_x;
det.y = win.y + det.y * scale_y;
det.w *= scale_x;
det.h *= scale_y;
}
return mapped;
}
/**
* 获取预配置窗口数量
*/
size_t GetConfiguredWindowCount() const {
return windows_.size();
}
private:
/**
* 自动计算窗口(覆盖整个图像)
* 策略:生成重叠的 640x640 窗口网格
*/
std::vector<DetectionWindow> CalculateWindowsAuto(int src_w, int src_h) const {
std::vector<DetectionWindow> windows;
const int win_size = 640;
// 计算步长(带重叠)
int step_x = (src_w <= win_size) ? src_w : (src_w - win_size) / ((src_w + win_size - 1) / win_size - 1);
int step_y = (src_h <= win_size) ? src_h : (src_h - win_size) / ((src_h + win_size - 1) / win_size - 1);
if (step_x < win_size) step_x = win_size;
if (step_y < win_size) step_y = win_size;
for (int y = 0; y < src_h; y += step_y) {
for (int x = 0; x < src_w; x += step_x) {
DetectionWindow win;
win.x = x;
win.y = y;
win.w = win_size;
win.h = win_size;
windows.push_back(win);
if (x + win_size >= src_w) break;
}
if (y + win_size >= src_h) break;
}
return windows;
}
/**
* RGB 图像双线性 resize
* @param swap_rb 是否交换 R/B 通道
*/
static void ResizeRgbBilinear(const uint8_t* src, int src_w, int src_h, int src_stride,
uint8_t* dst, int dst_w, int dst_h, bool swap_rb) {
const float scale_x = static_cast<float>(src_w) / dst_w;
const float scale_y = static_cast<float>(src_h) / dst_h;
for (int y = 0; y < dst_h; ++y) {
float fy = y * scale_y;
int y0 = static_cast<int>(fy);
int y1 = std::min(y0 + 1, src_h - 1);
float dy = fy - y0;
for (int x = 0; x < dst_w; ++x) {
float fx = x * scale_x;
int x0 = static_cast<int>(fx);
int x1 = std::min(x0 + 1, src_w - 1);
float dx = fx - x0;
// 双线性插值
for (int c = 0; c < 3; ++c) {
int src_c = swap_rb ? (2 - c) : c;
float v00 = src[(y0 * src_stride) + (x0 * 3) + src_c];
float v01 = src[(y0 * src_stride) + (x1 * 3) + src_c];
float v10 = src[(y1 * src_stride) + (x0 * 3) + src_c];
float v11 = src[(y1 * src_stride) + (x1 * 3) + src_c];
float v0 = v00 * (1 - dx) + v01 * dx;
float v1 = v10 * (1 - dx) + v11 * dx;
float v = v0 * (1 - dy) + v1 * dy;
dst[(y * dst_w + x) * 3 + c] = static_cast<uint8_t>(v);
}
}
}
}
std::vector<DetectionWindow> windows_;
int target_height_ = 640;
};
} // namespace rk3588