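/**
 * sliding_window_detector.h - generic sliding-window detection helper
 *
 * Features:
 *   1. Supports several pre-configured windows, or automatically computed windows
 *   2. Crops a window from the source image and resizes it to the model input size
 *   3. Maps detection coordinates back to the original image
 *   4. NMS merging of multi-window results (note: no NMS routine is defined in
 *      this header; the merge has to be performed by the caller)
 *
 * Usage:
 *   SlidingWindowDetector swd;
 *   swd.InitFromConfig(config);  // initialise the windows from configuration
 *   auto windows = swd.GetWindows(src_w, src_h);
 *   std::vector<uint8_t> input(static_cast<size_t>(model_w) * model_h * 3);
 *   for (auto& win : windows) {
 *       if (!swd.PrepareInput(frame, win, model_w, model_h, input.data())) continue;
 *       // ... inference ...
 *       auto dets = swd.MapDetectionsToOriginal(raw_dets, win, model_w, model_h);
 *   }
 */
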
#pragma once
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

#include "frame/frame.h"
#include "utils/simple_json.h"
namespace rk3588 {
/**
 * Rectangular detection window, expressed in source-image pixel coordinates.
 *
 * (x, y) is the top-left corner and (w, h) the size. A default-constructed
 * window is 640x640 at the origin.
 */
struct DetectionWindow {
    int x = 0;    ///< left edge
    int y = 0;    ///< top edge
    int w = 640;  ///< width
    int h = 640;  ///< height

    /// A window is usable only if both dimensions are strictly positive.
    /// The position (x, y) is not checked here.
    bool IsValid() const { return w > 0 && h > 0; }
};
/**
 * Detection box in a generic format: top-left corner plus width/height,
 * with a confidence score and a class index.
 *
 * All members are zero-initialised so that a default-constructed box never
 * carries indeterminate values. The struct is still an aggregate, so
 * `DetectionBox{x, y, w, h, conf, cls}` keeps working.
 */
struct DetectionBox {
    float x = 0.0f;           ///< left edge
    float y = 0.0f;           ///< top edge
    float w = 0.0f;           ///< width
    float h = 0.0f;           ///< height
    float confidence = 0.0f;  ///< detection score
    int class_id = 0;         ///< class index
};
/**
|
||
* 滑动窗口检测器
|
||
*/
|
||
class SlidingWindowDetector {
|
||
public:
|
||
SlidingWindowDetector() = default;
|
||
|
||
/**
|
||
* 从配置初始化
|
||
* @param config SimpleJson 配置,支持 "windows" 数组
|
||
* @return 是否成功
|
||
*/
|
||
bool InitFromConfig(const SimpleJson& config) {
|
||
windows_.clear();
|
||
|
||
// 解析预配置窗口
|
||
if (const SimpleJson* win_arr = config.Find("windows"); win_arr && win_arr->IsArray()) {
|
||
for (const auto& w : win_arr->AsArray()) {
|
||
if (w.IsObject()) {
|
||
DetectionWindow win;
|
||
win.x = w.ValueOr<int>("x", 0);
|
||
win.y = w.ValueOr<int>("y", 0);
|
||
win.w = w.ValueOr<int>("w", 640);
|
||
win.h = w.ValueOr<int>("h", 640);
|
||
if (win.IsValid()) {
|
||
windows_.push_back(win);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 目标 resize 高度(用于自动计算窗口时)
|
||
target_height_ = config.ValueOr<int>("target_height", 640);
|
||
|
||
return true;
|
||
}
|
||
|
||
/**
|
||
* 获取窗口列表(预配置或自动计算)
|
||
* @param src_w 源图宽度
|
||
* @param src_h 源图高度
|
||
* @return 窗口列表
|
||
*/
|
||
std::vector<DetectionWindow> GetWindows(int src_w, int src_h) const {
|
||
if (!windows_.empty()) {
|
||
return windows_;
|
||
}
|
||
return CalculateWindowsAuto(src_w, src_h);
|
||
}
|
||
|
||
/**
|
||
* 准备模型输入
|
||
* 从源图裁剪窗口区域,resize 到模型输入尺寸
|
||
*
|
||
* @param frame 源帧
|
||
* @param win 窗口定义
|
||
* @param model_w 模型输入宽
|
||
* @param model_h 模型输入高
|
||
* @param output 输出缓冲区(model_w * model_h * 3)
|
||
* @return 是否成功
|
||
*/
|
||
bool PrepareInput(const FramePtr& frame,
|
||
const DetectionWindow& win,
|
||
int model_w, int model_h,
|
||
uint8_t* output) const {
|
||
if (!frame || !frame->data || !output) return false;
|
||
|
||
const int src_w = frame->width;
|
||
const int src_h = frame->height;
|
||
|
||
// 获取源数据指针
|
||
const uint8_t* src = frame->planes[0].data ? frame->planes[0].data : frame->data;
|
||
const int src_stride = frame->planes[0].stride > 0 ? frame->planes[0].stride
|
||
: (frame->stride > 0 ? frame->stride : frame->width * 3);
|
||
|
||
// 限制窗口在源图范围内
|
||
int win_x = std::max(0, std::min(win.x, src_w - 1));
|
||
int win_y = std::max(0, std::min(win.y, src_h - 1));
|
||
int win_w = std::min(win.w, src_w - win_x);
|
||
int win_h = std::min(win.h, src_h - win_y);
|
||
|
||
if (win_w <= 0 || win_h <= 0) return false;
|
||
|
||
// 裁剪窗口
|
||
std::vector<uint8_t> crop_buf(static_cast<size_t>(win_w) * win_h * 3);
|
||
for (int row = 0; row < win_h; ++row) {
|
||
const uint8_t* src_row = src + (win_y + row) * src_stride + win_x * 3;
|
||
uint8_t* dst_row = crop_buf.data() + row * win_w * 3;
|
||
memcpy(dst_row, src_row, static_cast<size_t>(win_w) * 3);
|
||
}
|
||
|
||
// Resize 到模型输入尺寸
|
||
ResizeRgbBilinear(crop_buf.data(), win_w, win_h, win_w * 3,
|
||
output, model_w, model_h, false);
|
||
|
||
return true;
|
||
}
|
||
|
||
/**
|
||
* 将检测结果从模型坐标映射回原图坐标
|
||
*
|
||
* @param detections 模型输出的检测框(在 model_w x model_h 坐标系中)
|
||
* @param win 窗口定义
|
||
* @param model_w 模型输入宽
|
||
* @param model_h 模型输入高
|
||
* @return 映射后的检测框(在原图坐标系中)
|
||
*/
|
||
std::vector<DetectionBox> MapDetectionsToOriginal(
|
||
const std::vector<DetectionBox>& detections,
|
||
const DetectionWindow& win,
|
||
int model_w, int model_h) const {
|
||
|
||
std::vector<DetectionBox> mapped = detections;
|
||
|
||
float scale_x = static_cast<float>(win.w) / static_cast<float>(model_w);
|
||
float scale_y = static_cast<float>(win.h) / static_cast<float>(model_h);
|
||
|
||
for (auto& det : mapped) {
|
||
det.x = win.x + det.x * scale_x;
|
||
det.y = win.y + det.y * scale_y;
|
||
det.w *= scale_x;
|
||
det.h *= scale_y;
|
||
}
|
||
|
||
return mapped;
|
||
}
|
||
|
||
/**
|
||
* 获取预配置窗口数量
|
||
*/
|
||
size_t GetConfiguredWindowCount() const {
|
||
return windows_.size();
|
||
}
|
||
|
||
private:
|
||
/**
|
||
* 自动计算窗口(覆盖整个图像)
|
||
* 策略:生成重叠的 640x640 窗口网格
|
||
*/
|
||
std::vector<DetectionWindow> CalculateWindowsAuto(int src_w, int src_h) const {
|
||
std::vector<DetectionWindow> windows;
|
||
|
||
const int win_size = 640;
|
||
|
||
// 计算步长(带重叠)
|
||
int step_x = (src_w <= win_size) ? src_w : (src_w - win_size) / ((src_w + win_size - 1) / win_size - 1);
|
||
int step_y = (src_h <= win_size) ? src_h : (src_h - win_size) / ((src_h + win_size - 1) / win_size - 1);
|
||
|
||
if (step_x < win_size) step_x = win_size;
|
||
if (step_y < win_size) step_y = win_size;
|
||
|
||
for (int y = 0; y < src_h; y += step_y) {
|
||
for (int x = 0; x < src_w; x += step_x) {
|
||
DetectionWindow win;
|
||
win.x = x;
|
||
win.y = y;
|
||
win.w = win_size;
|
||
win.h = win_size;
|
||
windows.push_back(win);
|
||
|
||
if (x + win_size >= src_w) break;
|
||
}
|
||
if (y + win_size >= src_h) break;
|
||
}
|
||
|
||
return windows;
|
||
}
|
||
|
||
/**
|
||
* RGB 图像双线性 resize
|
||
* @param swap_rb 是否交换 R/B 通道
|
||
*/
|
||
static void ResizeRgbBilinear(const uint8_t* src, int src_w, int src_h, int src_stride,
|
||
uint8_t* dst, int dst_w, int dst_h, bool swap_rb) {
|
||
const float scale_x = static_cast<float>(src_w) / dst_w;
|
||
const float scale_y = static_cast<float>(src_h) / dst_h;
|
||
|
||
for (int y = 0; y < dst_h; ++y) {
|
||
float fy = y * scale_y;
|
||
int y0 = static_cast<int>(fy);
|
||
int y1 = std::min(y0 + 1, src_h - 1);
|
||
float dy = fy - y0;
|
||
|
||
for (int x = 0; x < dst_w; ++x) {
|
||
float fx = x * scale_x;
|
||
int x0 = static_cast<int>(fx);
|
||
int x1 = std::min(x0 + 1, src_w - 1);
|
||
float dx = fx - x0;
|
||
|
||
// 双线性插值
|
||
for (int c = 0; c < 3; ++c) {
|
||
int src_c = swap_rb ? (2 - c) : c;
|
||
|
||
float v00 = src[(y0 * src_stride) + (x0 * 3) + src_c];
|
||
float v01 = src[(y0 * src_stride) + (x1 * 3) + src_c];
|
||
float v10 = src[(y1 * src_stride) + (x0 * 3) + src_c];
|
||
float v11 = src[(y1 * src_stride) + (x1 * 3) + src_c];
|
||
|
||
float v0 = v00 * (1 - dx) + v01 * dx;
|
||
float v1 = v10 * (1 - dx) + v11 * dx;
|
||
float v = v0 * (1 - dy) + v1 * dy;
|
||
|
||
dst[(y * dst_w + x) * 3 + c] = static_cast<uint8_t>(v);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
std::vector<DetectionWindow> windows_;
|
||
int target_height_ = 640;
|
||
};
} // namespace rk3588