// OrangePi3588Media/plugins/ai_scrfd/ai_scrfd_node.cpp
/**
* ai_scrfd - SCRFD 640x640 face detection node for RK3588
*
* Reference: https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/ort/cv/scrfd.cpp
* BBox format: [left, top, right, bottom] offsets from grid center
*/
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
#include "face/face_result.h"
#include "hw/i_infer_backend.h"
#include "node.h"
#include "utils/dma_alloc.h"
#include "utils/logger.h"
namespace rk3588 {
// Tunable detection parameters, populated from the node's JSON config in Init().
struct ScrfdConfig {
// Minimum classification score for a candidate box to be kept (see ParseOutputs).
float conf_thresh = 0.5f;
// IoU threshold for greedy NMS suppression of overlapping boxes (see ApplyNMS).
float nms_thresh = 0.4f;
// Hard cap on the number of faces reported per frame (applied after NMS).
int max_faces = 50;
// When true, also decode the 5-point landmark head into each detection.
bool output_landmarks = true;
// Channel order the model expects: "rgb" or "bgr"; frames in the other
// order get their R/B channels swapped in PrepareInput.
std::string input_format = "rgb";
};
// Precomputed anchor center: grid-cell coordinates plus the stride that maps
// them into 640x640 model pixel space during output decoding.
struct CenterPoint {
float cx, cy; // grid coordinates (0,0), (1,0), ...
float stride;
};
/**
 * SCRFD 640x640 face detection node.
 *
 * Per frame: bilinear-resize to 640x640, optionally swap R/B channels,
 * run inference through the infer backend, decode the 9 output tensors
 * ([left, top, right, bottom] offsets from precomputed grid centers —
 * NOT [dx, dy, dw, dh]), apply greedy NMS, and attach a FaceDetResult
 * to the frame before forwarding it downstream.
 */
class AiScrfdNode : public INode {
public:
    std::string Id() const override { return id_; }
    std::string Type() const override { return "ai_scrfd"; }

    /**
     * Reads configuration, precomputes the anchor center points and, when
     * RKNN support is compiled in, loads the model via the infer backend.
     *
     * @return false when model_path is missing, the input queue or infer
     *         backend is absent, or the model fails to load.
     */
    bool Init(const SimpleJson& config, const NodeContext& ctx) override {
        id_ = config.ValueOr<std::string>("id", "scrfd");
        model_path_ = config.ValueOr<std::string>("model_path", "");
        if (model_path_.empty()) {
            LogError("[ai_scrfd] model_path is required");
            return false;
        }
        cfg_.conf_thresh = config.ValueOr<float>("conf_thresh", 0.5f);
        cfg_.nms_thresh = config.ValueOr<float>("nms_thresh", 0.4f);
        cfg_.max_faces = config.ValueOr<int>("max_faces", 50);
        cfg_.output_landmarks = config.ValueOr<bool>("output_landmarks", true);
        cfg_.input_format = config.ValueOr<std::string>("input_format", "rgb");
        model_w_ = 640;
        model_h_ = 640;
        // Center points are consumed by output decoding; generated once here.
        GenerateCenterPoints();
        input_queue_ = ctx.input_queue;
        output_queues_ = ctx.output_queues;
        if (!input_queue_) {
            LogError("[ai_scrfd] no input queue");
            return false;
        }
        infer_backend_ = ctx.infer_backend;
        if (!infer_backend_) {
            LogError("[ai_scrfd] no infer backend");
            return false;
        }
#if defined(RK3588_ENABLE_RKNN)
        std::string err;
        model_handle_ = infer_backend_->LoadModel(model_path_, err);
        if (model_handle_ == kInvalidModelHandle) {
            LogError("[ai_scrfd] failed to load model: " + err);
            return false;
        }
        // Packed HWC uint8 staging buffer for the resized model input.
        input_buf_.resize(static_cast<size_t>(model_w_) * model_h_ * 3);
        LogInfo("[ai_scrfd] model loaded: " + model_path_);
#else
        LogWarn("[ai_scrfd] RKNN disabled");
#endif
        return true;
    }

    bool Start() override {
        LogInfo("[ai_scrfd] start");
        return true;
    }

    // Releases the loaded model (RKNN builds only).
    void Stop() override {
#if defined(RK3588_ENABLE_RKNN)
        if (model_handle_ != kInvalidModelHandle) {
            infer_backend_->UnloadModel(model_handle_);
            model_handle_ = kInvalidModelHandle;
        }
#endif
        LogInfo("[ai_scrfd] stop");
    }

    /**
     * Runs detection (when RKNN is enabled) and always forwards the frame.
     * @return DROP for null frames, OK otherwise — inference failures are
     *         logged but never block the pipeline.
     */
    NodeStatus Process(FramePtr frame) override {
        if (!frame) return NodeStatus::DROP;
#if defined(RK3588_ENABLE_RKNN)
        RunDetection(frame);
#endif
        Push(frame);
        return NodeStatus::OK;
    }

private:
    // Fans the frame out to every downstream queue.
    void Push(FramePtr frame) {
        for (auto& q : output_queues_) q->Push(frame);
    }

    /**
     * Precomputes one CenterPoint per output anchor, in the exact order the
     * model emits scores: stride 8, then 16, then 32; row-major grid scan;
     * 2 anchors per grid cell. For 640x640 this yields 12800 + 3200 + 800
     * anchors, matching the counts decoded in ParseOutputs.
     *
     * BUGFIX: this method is deliberately OUTSIDE the RK3588_ENABLE_RKNN
     * guard — Init() calls it unconditionally, so defining it only in RKNN
     * builds broke compilation when RKNN was disabled.
     */
    void GenerateCenterPoints() {
        const int strides[] = {8, 16, 32};
        for (int stride : strides) {
            const int num_grid = model_w_ / stride;
            for (int y = 0; y < num_grid; ++y) {
                for (int x = 0; x < num_grid; ++x) {
                    for (int a = 0; a < 2; ++a) {  // 2 anchors per location
                        CenterPoint pt;
                        pt.cx = static_cast<float>(x);  // grid x
                        pt.cy = static_cast<float>(y);  // grid y
                        pt.stride = static_cast<float>(stride);
                        center_points_.push_back(pt);
                    }
                }
            }
        }
        LogInfo("[ai_scrfd] Generated " + std::to_string(center_points_.size()) +
                " center points");
    }

#if defined(RK3588_ENABLE_RKNN)
    /**
     * Full detection pass for one frame. On success attaches
     * frame->face_det; on inference failure logs a warning and returns
     * without touching the frame.
     */
    void RunDetection(FramePtr frame) {
        if (!frame->data || frame->data_size == 0) return;
        const int src_w = frame->width;
        const int src_h = frame->height;
        // NOTE(review): SyncStart() has no matching SyncEnd() on any path —
        // confirm whether the DMA buffer must be released after CPU access.
        if (frame->DmaFd() >= 0) frame->SyncStart();
        PrepareInput(frame, model_w_, model_h_);
        InferInput input;
        input.width = model_w_;
        input.height = model_h_;
        input.is_nhwc = true;  // packed HWC uint8, matching input_buf_ layout
        input.data = input_buf_.data();
        input.size = input_buf_.size();
        input.type = RKNN_TENSOR_UINT8;
        auto r = infer_backend_->InferBorrowed(model_handle_, input);
        if (!r.success) {
            LogWarn("[ai_scrfd] inference failed: " + r.error);
            return;
        }
        std::vector<FaceDetItem> detections = ParseOutputs(r.outputs, src_w, src_h);
        detections = ApplyNMS(detections, cfg_.nms_thresh);
        // ApplyNMS returns score-sorted results, so resize keeps the best.
        if (detections.size() > static_cast<size_t>(cfg_.max_faces)) {
            detections.resize(cfg_.max_faces);
        }
        FaceDetResult result;
        result.img_w = src_w;
        result.img_h = src_h;
        result.model_name = "scrfd_640";
        result.faces = std::move(detections);
        frame->face_det = std::make_shared<FaceDetResult>(std::move(result));
    }

    /**
     * Resizes the frame into input_buf_ (dst_w x dst_h, packed 3-channel)
     * and swaps R/B in place when the frame's channel order differs from
     * cfg_.input_format.
     */
    void PrepareInput(FramePtr frame, int dst_w, int dst_h) {
        const uint8_t* src = frame->planes[0].data ? frame->planes[0].data : frame->data;
        const int src_stride = frame->planes[0].stride > 0
                                   ? frame->planes[0].stride
                                   : (frame->stride > 0 ? frame->stride : frame->width * 3);
        // Simple bilinear resize into the staging buffer.
        ResizeBilinear(src, frame->width, frame->height, src_stride,
                       input_buf_.data(), dst_w, dst_h, dst_w * 3);
        // RGB<->BGR conversion if the frame order differs from the model's.
        const bool need_swap =
            (frame->format == PixelFormat::BGR && cfg_.input_format == "rgb") ||
            (frame->format == PixelFormat::RGB && cfg_.input_format == "bgr");
        if (need_swap) {
            for (int i = 0; i < dst_w * dst_h * 3; i += 3) {
                std::swap(input_buf_[i], input_buf_[i + 2]);
            }
        }
    }

    /**
     * CPU bilinear resize for packed 3-channel 8-bit images. Sample centers
     * are aligned via the +/-0.5 convention; edge taps are clamped (x0 may
     * clamp onto x1, in which case the interpolation degenerates safely).
     */
    void ResizeBilinear(const uint8_t* src, int src_w, int src_h, int src_stride,
                        uint8_t* dst, int dst_w, int dst_h, int dst_stride) {
        const float x_ratio = static_cast<float>(src_w) / dst_w;
        const float y_ratio = static_cast<float>(src_h) / dst_h;
        for (int y = 0; y < dst_h; ++y) {
            for (int x = 0; x < dst_w; ++x) {
                const float sx = (x + 0.5f) * x_ratio - 0.5f;
                const float sy = (y + 0.5f) * y_ratio - 0.5f;
                int x0 = static_cast<int>(std::floor(sx));
                int y0 = static_cast<int>(std::floor(sy));
                const int x1 = std::min(x0 + 1, src_w - 1);
                const int y1 = std::min(y0 + 1, src_h - 1);
                x0 = std::max(0, x0);
                y0 = std::max(0, y0);
                const float fx = sx - x0;
                const float fy = sy - y0;
                for (int c = 0; c < 3; ++c) {
                    const float v00 = src[y0 * src_stride + x0 * 3 + c];
                    const float v01 = src[y0 * src_stride + x1 * 3 + c];
                    const float v10 = src[y1 * src_stride + x0 * 3 + c];
                    const float v11 = src[y1 * src_stride + x1 * 3 + c];
                    const float v = v00 * (1 - fx) * (1 - fy) + v01 * fx * (1 - fy) +
                                    v10 * (1 - fx) * fy + v11 * fx * fy;
                    dst[y * dst_stride + x * 3 + c] = static_cast<uint8_t>(v);
                }
            }
        }
    }

    /**
     * Parses the 9 SCRFD output tensors into image-space detections.
     * Reference decode: lite.ai.toolkit scrfd.cpp.
     *
     * BBox format per anchor is [left, top, right, bottom] — distances from
     * the grid center in grid units (multiplied by stride for pixels) —
     * NOT [dx, dy, dw, dh].
     *
     * @param outputs  expected order: score_8, score_16, score_32,
     *                 bbox_8, bbox_16, bbox_32, kps_8, kps_16, kps_32.
     * @param src_w/src_h  original frame size; results are scaled to it.
     */
    std::vector<FaceDetItem> ParseOutputs(
            const std::vector<AiScheduler::BorrowedOutput>& outputs,
            int src_w, int src_h) {
        std::vector<FaceDetItem> detections;
        if (outputs.size() != 9) return detections;
        const int anchor_counts[] = {12800, 3200, 800};  // (640/stride)^2 * 2
        const int strides[] = {8, 16, 32};
        // Model space (640x640) -> source image scale, invariant per frame.
        const float scale_x = static_cast<float>(src_w) / model_w_;
        const float scale_y = static_cast<float>(src_h) / model_h_;
        size_t anchor_idx = 0;
        for (int s = 0; s < 3; ++s) {
            const int stride = strides[s];
            const int count = anchor_counts[s];
            const auto& score_out = outputs[s];
            const auto& bbox_out = outputs[s + 3];
            const auto& kps_out = outputs[s + 6];
            const float* scores = reinterpret_cast<const float*>(score_out.data);
            const float* bboxes = reinterpret_cast<const float*>(bbox_out.data);
            const float* kps = reinterpret_cast<const float*>(kps_out.data);
            if (score_out.dims.size() < 3 || !scores || !bboxes || !kps) {
                // BUGFIX: a skipped stride must still consume its anchors,
                // otherwise later strides decode against the wrong centers.
                anchor_idx += static_cast<size_t>(count);
                continue;
            }
            for (int i = 0; i < count; ++i, ++anchor_idx) {
                if (anchor_idx >= center_points_.size()) break;
                const float score = scores[i];
                if (score < cfg_.conf_thresh) continue;
                const CenterPoint& pt = center_points_[anchor_idx];
                // Distances from the grid center: [left, top, right, bottom].
                const float left = bboxes[i * 4 + 0];
                const float top = bboxes[i * 4 + 1];
                const float right = bboxes[i * 4 + 2];
                const float bottom = bboxes[i * 4 + 3];
                // Decode to 640x640 model-space coordinates.
                const float x1_640 = (pt.cx - left) * stride;
                const float y1_640 = (pt.cy - top) * stride;
                const float x2_640 = (pt.cx + right) * stride;
                const float y2_640 = (pt.cy + bottom) * stride;
                FaceDetItem det;
                det.bbox.x = x1_640 * scale_x;
                det.bbox.y = y1_640 * scale_y;
                det.bbox.w = (x2_640 - x1_640) * scale_x;
                det.bbox.h = (y2_640 - y1_640) * scale_y;
                det.score = score;
                det.has_landmarks = cfg_.output_landmarks;
                if (cfg_.output_landmarks) {
                    // 5 landmarks, each an (x, y) offset from the grid center.
                    for (int p = 0; p < 5; ++p) {
                        const float kps_x = kps[i * 10 + p * 2 + 0];
                        const float kps_y = kps[i * 10 + p * 2 + 1];
                        det.landmarks[p].x = (pt.cx + kps_x) * stride * scale_x;
                        det.landmarks[p].y = (pt.cy + kps_y) * stride * scale_y;
                    }
                }
                detections.push_back(det);
            }
        }
        return detections;
    }

    // Intersection-over-union of two axis-aligned rects; 0 when the union
    // area is non-positive.
    float IoU(const Rect& a, const Rect& b) {
        const float x1 = std::max(a.x, b.x);
        const float y1 = std::max(a.y, b.y);
        const float x2 = std::min(a.x + a.w, b.x + b.w);
        const float y2 = std::min(a.y + a.h, b.y + b.h);
        const float inter = std::max(0.0f, x2 - x1) * std::max(0.0f, y2 - y1);
        const float union_area = a.w * a.h + b.w * b.h - inter;
        return union_area > 0 ? inter / union_area : 0;
    }

    /**
     * Greedy NMS: sorts by descending score, then suppresses any box whose
     * IoU with an already-kept box exceeds thresh. Sorts `dets` in place;
     * the returned vector remains score-sorted.
     */
    std::vector<FaceDetItem> ApplyNMS(std::vector<FaceDetItem>& dets, float thresh) {
        if (dets.empty()) return dets;
        std::sort(dets.begin(), dets.end(),
                  [](const FaceDetItem& a, const FaceDetItem& b) {
                      return a.score > b.score;
                  });
        std::vector<FaceDetItem> keep;
        std::vector<bool> suppressed(dets.size(), false);
        for (size_t i = 0; i < dets.size(); ++i) {
            if (suppressed[i]) continue;
            keep.push_back(dets[i]);
            for (size_t j = i + 1; j < dets.size(); ++j) {
                if (suppressed[j]) continue;
                if (IoU(dets[i].bbox, dets[j].bbox) > thresh) {
                    suppressed[j] = true;
                }
            }
        }
        return keep;
    }
#endif

    std::string id_;
    std::string model_path_;
    ScrfdConfig cfg_;
    int model_w_ = 640;
    int model_h_ = 640;
    std::vector<CenterPoint> center_points_;
    std::shared_ptr<SpscQueue<FramePtr>> input_queue_;
    std::vector<std::shared_ptr<SpscQueue<FramePtr>>> output_queues_;
    std::shared_ptr<IInferBackend> infer_backend_;
#if defined(RK3588_ENABLE_RKNN)
    ModelHandle model_handle_ = kInvalidModelHandle;
    std::vector<uint8_t> input_buf_;
#endif
};
// Registers AiScrfdNode under the type id "ai_scrfd" (matches Type()).
// NOTE(review): macro presumably declared in node.h — registration side
// effects (factory vs static table) are not visible from this file.
REGISTER_NODE(AiScrfdNode, "ai_scrfd");
} // namespace rk3588