#include "pose_assoc_node.h"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "pose/pose_result.h"
#include "utils/logger.h"

namespace rk3588 {

namespace {

/// Tunables for the pose <-> detection association step.
struct PoseAssocConfig {
  /// Minimum IoU between a pose bbox and a detection bbox for the pair to be
  /// considered a match. Validated to lie in [0, 1] by ParseConfig.
  float min_iou = 0.1f;
};

/// One admissible (pose, detection) pairing together with its overlap score.
struct MatchCandidate {
  float iou = 0.0f;
  int pose_index = -1;
  int det_index = -1;
};

/// Reads the node configuration into `out`.
/// @param config  node JSON config; "min_iou" defaults to 0.1.
/// @param out     receives the parsed values.
/// @param err     set to a human-readable message on failure.
/// @return false when a value is out of range.
bool ParseConfig(const SimpleJson& config, PoseAssocConfig& out, std::string& err) {
  out.min_iou = config.ValueOr<float>("min_iou", 0.1f);
  // Negated in-range test so NaN is rejected as well: every comparison with
  // NaN is false, so `x < 0 || x > 1` would let it through and every later
  // `iou >= min_iou` check would then silently fail.
  if (!(out.min_iou >= 0.0f && out.min_iou <= 1.0f)) {
    err = "min_iou must be in [0, 1]";
    return false;
  }
  return true;
}

/// Intersection-over-union of two axis-aligned rectangles given as
/// (x, y, w, h). Rectangles with a negative width or height contribute no
/// area. Returns 0 when the union is empty.
float IoU(const Rect& lhs, const Rect& rhs) {
  const float x1 = std::max(lhs.x, rhs.x);
  const float y1 = std::max(lhs.y, rhs.y);
  const float x2 = std::min(lhs.x + lhs.w, rhs.x + rhs.w);
  const float y2 = std::min(lhs.y + lhs.h, rhs.y + rhs.h);
  const float iw = std::max(0.0f, x2 - x1);
  const float ih = std::max(0.0f, y2 - y1);
  const float inter = iw * ih;
  const float area_l = std::max(0.0f, lhs.w) * std::max(0.0f, lhs.h);
  const float area_r = std::max(0.0f, rhs.w) * std::max(0.0f, rhs.h);
  const float uni = area_l + area_r - inter;
  return uni <= 0.0f ? 0.0f : (inter / uni);
}

/// Forwards `frame` to every non-null output queue.
/// Templated on the container so it accepts whatever NodeContext::output_queues
/// holds (presumably a vector of pointer-like queues exposing Push -- confirm
/// against node context header).
template <typename QueueList>
void PushToDownstream(const QueueList& output_queues, const FramePtr& frame) {
  for (const auto& q : output_queues) {
    if (q) q->Push(frame);
  }
}

/// Greedy one-to-one association of pose instances with tracked detections.
///
/// Every pose's track_id is first reset to -1. Candidate pairs are then
/// visited in descending IoU order, and each still-unused pose takes the
/// track_id of its best still-unused detection. Detections with
/// track_id < 0 are skipped. A pair qualifies only when it actually overlaps
/// (iou > 0) and iou >= min_iou.
template <typename PoseItems, typename DetItems>
void AssignTrackIds(PoseItems& poses, const DetItems& dets, float min_iou) {
  for (auto& pose : poses) {
    pose.track_id = -1;
  }
  if (poses.empty() || dets.empty()) return;

  std::vector<MatchCandidate> candidates;
  candidates.reserve(poses.size() * dets.size());
  for (std::size_t pi = 0; pi < poses.size(); ++pi) {
    const Rect& pose_bbox = poses[pi].bbox;
    for (std::size_t di = 0; di < dets.size(); ++di) {
      const auto& det = dets[di];
      if (det.track_id < 0) continue;
      const float iou = IoU(pose_bbox, det.bbox);
      // `iou > 0` guards the min_iou == 0 case: without it a disjoint pair
      // (iou == 0) would be accepted and inherit an unrelated track id.
      if (iou > 0.0f && iou >= min_iou) {
        candidates.push_back(
            MatchCandidate{iou, static_cast<int>(pi), static_cast<int>(di)});
      }
    }
  }

  // Highest IoU first; index tie-breaks keep the assignment deterministic.
  std::sort(candidates.begin(), candidates.end(),
            [](const MatchCandidate& lhs, const MatchCandidate& rhs) {
              if (lhs.iou != rhs.iou) return lhs.iou > rhs.iou;
              if (lhs.pose_index != rhs.pose_index) return lhs.pose_index < rhs.pose_index;
              return lhs.det_index < rhs.det_index;
            });

  // std::vector<char> rather than the packed std::vector<bool> proxy.
  std::vector<char> used_pose(poses.size(), 0);
  std::vector<char> used_det(dets.size(), 0);
  for (const auto& candidate : candidates) {
    const auto p = static_cast<std::size_t>(candidate.pose_index);
    const auto d = static_cast<std::size_t>(candidate.det_index);
    if (used_pose[p] || used_det[d]) continue;
    poses[p].track_id = dets[d].track_id;
    used_pose[p] = 1;
    used_det[d] = 1;
  }
}

}  // namespace

struct PoseAssocNode::Impl {
  PoseAssocConfig config;
  std::string init_err;
};

PoseAssocNode::PoseAssocNode() : impl_(std::make_unique<Impl>()) {}

PoseAssocNode::~PoseAssocNode() = default;

std::string PoseAssocNode::Id() const { return id_; }

std::string PoseAssocNode::Type() const { return "pose_assoc"; }

/// Reads "id" (default "pose_assoc") and "min_iou", then stores the output
/// queues from `ctx`. Returns false and logs when the config is invalid.
bool PoseAssocNode::Init(const SimpleJson& config, const NodeContext& ctx) {
  id_ = config.ValueOr<std::string>("id", "pose_assoc");
  if (!ParseConfig(config, impl_->config, impl_->init_err)) {
    LogError("[pose_assoc] invalid config: " + impl_->init_err);
    return false;
  }
  output_queues_ = ctx.output_queues;
  return true;
}

bool PoseAssocNode::Start() { return true; }

void PoseAssocNode::Stop() {}

/// Copies detection track ids onto overlapping pose instances (only when the
/// frame carries both pose and detection results), then forwards the frame to
/// every downstream queue. Null frames are dropped.
NodeStatus PoseAssocNode::Process(FramePtr frame) {
  if (!frame) return NodeStatus::DROP;

  if (frame->pose && frame->det) {
    AssignTrackIds(frame->pose->items, frame->det->items, impl_->config.min_iou);
  }

  PushToDownstream(output_queues_, frame);
  return NodeStatus::OK;
}

#ifndef RK3588_TEST_BUILD
REGISTER_NODE(PoseAssocNode, "pose_assoc");
#endif

}  // namespace rk3588