86 lines
1.7 KiB
C++
86 lines
1.7 KiB
C++
#pragma once
|
|
|
|
/**
|
|
* SCRFD Detector - 可复用的 SCRFD 检测器
|
|
* 供 ai_scrfd 和 ai_scrfd_zoned 节点使用
|
|
*/
|
|
|
|
#include <vector>
|
|
#include <cstdint>
|
|
#include "face/face_result.h"
|
|
|
|
// 包含 AiScheduler 以使用 BorrowedOutput
|
|
#include "ai_scheduler.h"
|
|
|
|
namespace rk3588 {
|
|
|
|
/**
|
|
* SCRFD 检测结果
|
|
*/
|
|
struct ScrfdDetection {
|
|
FaceDetItem item;
|
|
};
|
|
|
|
/**
|
|
* SCRFD 检测器配置
|
|
*/
|
|
struct ScrfdConfig {
|
|
float conf_thresh = 0.5f;
|
|
float nms_thresh = 0.4f;
|
|
int max_faces = 50;
|
|
bool output_landmarks = true;
|
|
};
|
|
|
|
/**
|
|
* SCRFD 检测器
|
|
*
|
|
* 使用示例:
|
|
* ScrfdDetector det;
|
|
* det.Init(640, 640);
|
|
* auto dets = det.Decode(outputs, src_w, src_h, config);
|
|
*/
|
|
class ScrfdDetector {
|
|
public:
|
|
ScrfdDetector();
|
|
~ScrfdDetector();
|
|
|
|
/**
|
|
* 初始化检测器
|
|
* @param model_w 模型输入宽度 (640)
|
|
* @param model_h 模型输入高度 (640)
|
|
*/
|
|
void Init(int model_w, int model_h);
|
|
|
|
/**
|
|
* 解码 SCRFD 输出
|
|
* @param outputs 9个输出张量 (BorrowedOutput)
|
|
* @param src_w 原始图像宽度
|
|
* @param src_h 原始图像高度
|
|
* @param cfg 检测配置
|
|
* @return 检测结果列表
|
|
*/
|
|
std::vector<FaceDetItem> Decode(
|
|
const std::vector<AiScheduler::BorrowedOutput>& outputs,
|
|
int src_w, int src_h,
|
|
const ScrfdConfig& cfg);
|
|
|
|
/**
|
|
* 应用 NMS
|
|
*/
|
|
std::vector<FaceDetItem> ApplyNMS(
|
|
std::vector<FaceDetItem>& dets,
|
|
float nms_thresh);
|
|
|
|
private:
|
|
struct CenterPoint {
|
|
float cx, cy;
|
|
float stride;
|
|
};
|
|
|
|
std::vector<CenterPoint> center_points_;
|
|
int model_w_ = 640;
|
|
int model_h_ = 640;
|
|
};
|
|
|
|
} // namespace rk3588
|