OrangePi3588Media/include/face/scrfd_detector.h

86 lines
1.7 KiB
C++

#pragma once
/**
* SCRFD Detector - 可复用的 SCRFD 检测器
* 供 ai_scrfd 和 ai_scrfd_zoned 节点使用
*/
#include <vector>
#include <cstdint>
#include "face/face_result.h"
// 包含 AiScheduler 以使用 BorrowedOutput
#include "ai_scheduler.h"
namespace rk3588 {
/**
* SCRFD 检测结果
*/
struct ScrfdDetection {
FaceDetItem item;
};
/**
* SCRFD 检测器配置
*/
struct ScrfdConfig {
float conf_thresh = 0.5f;
float nms_thresh = 0.4f;
int max_faces = 50;
bool output_landmarks = true;
};
/**
* SCRFD 检测器
*
* 使用示例:
* ScrfdDetector det;
* det.Init(640, 640);
* auto dets = det.Decode(outputs, src_w, src_h, config);
*/
class ScrfdDetector {
public:
ScrfdDetector();
~ScrfdDetector();
/**
* 初始化检测器
* @param model_w 模型输入宽度 (640)
* @param model_h 模型输入高度 (640)
*/
void Init(int model_w, int model_h);
/**
* 解码 SCRFD 输出
* @param outputs 9个输出张量 (BorrowedOutput)
* @param src_w 原始图像宽度
* @param src_h 原始图像高度
* @param cfg 检测配置
* @return 检测结果列表
*/
std::vector<FaceDetItem> Decode(
const std::vector<AiScheduler::BorrowedOutput>& outputs,
int src_w, int src_h,
const ScrfdConfig& cfg);
/**
* 应用 NMS
*/
std::vector<FaceDetItem> ApplyNMS(
std::vector<FaceDetItem>& dets,
float nms_thresh);
private:
struct CenterPoint {
float cx, cy;
float stride;
};
std::vector<CenterPoint> center_points_;
int model_w_ = 640;
int model_h_ = 640;
};
} // namespace rk3588