140 lines
4.3 KiB
C++
140 lines
4.3 KiB
C++
#include "pose_assoc_node.h"
|
|
|
|
#include <algorithm>
|
|
#include <memory>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "pose/pose_result.h"
|
|
#include "utils/logger.h"
|
|
|
|
namespace rk3588 {
|
|
namespace {
|
|
|
|
// Tunables for the pose/detection association step, read by ParseConfig().
struct PoseAssocConfig {
  // Minimum IoU a (pose, detection) pair must reach to be considered a match.
  float min_iou{0.1f};
};
|
|
|
|
// One admissible (pose, detection) pairing together with its overlap score.
// Member order is significant: instances are built with aggregate
// initialization as MatchCandidate{iou, pose_index, det_index}.
struct MatchCandidate {
  float iou{0.0f};     // IoU between the pose bbox and the detection bbox.
  int pose_index{-1};  // Index into the frame's pose items; -1 when unset.
  int det_index{-1};   // Index into the frame's detection items; -1 when unset.
};
|
|
|
|
// Reads the node's tunables from `config` into `out`.
//
// Recognised keys:
//   "min_iou" (float, default 0.1f) - must lie in the closed range [0, 1].
//
// Returns true on success. On failure returns false and stores a
// human-readable reason in `err`; `out.min_iou` then holds the rejected value.
bool ParseConfig(const SimpleJson& config, PoseAssocConfig& out, std::string& err) {
  out.min_iou = config.ValueOr<float>("min_iou", 0.1f);
  // Expressed as a negated inclusive range check so that NaN (for which every
  // ordered comparison is false) is rejected rather than accepted. A NaN
  // threshold would make every `iou >= min_iou` test fail and silently
  // disable all matching. Whether the JSON parser can yield NaN is up to
  // SimpleJson; this guard makes the check robust either way.
  if (!(out.min_iou >= 0.0f && out.min_iou <= 1.0f)) {
    err = "min_iou must be in [0, 1]";
    return false;
  }
  return true;
}
|
|
|
|
float IoU(const Rect& lhs, const Rect& rhs) {
|
|
const float x1 = std::max(lhs.x, rhs.x);
|
|
const float y1 = std::max(lhs.y, rhs.y);
|
|
const float x2 = std::min(lhs.x + lhs.w, rhs.x + rhs.w);
|
|
const float y2 = std::min(lhs.y + lhs.h, rhs.y + rhs.h);
|
|
const float iw = std::max(0.0f, x2 - x1);
|
|
const float ih = std::max(0.0f, y2 - y1);
|
|
const float inter = iw * ih;
|
|
const float area_l = std::max(0.0f, lhs.w) * std::max(0.0f, lhs.h);
|
|
const float area_r = std::max(0.0f, rhs.w) * std::max(0.0f, rhs.h);
|
|
const float uni = area_l + area_r - inter;
|
|
return uni <= 0.0f ? 0.0f : (inter / uni);
|
|
}
|
|
|
|
// Offers `frame` to every downstream queue in `output_queues`, in order.
// Null queue slots are skipped, and any return value of Push() is ignored.
void PushToDownstream(const std::vector<std::shared_ptr<SpscQueue<FramePtr>>>& output_queues, const FramePtr& frame) {
  for (const auto& queue : output_queues) {
    if (!queue) {
      continue;
    }
    queue->Push(frame);
  }
}
|
|
|
|
} // namespace
|
|
|
|
// Private state of PoseAssocNode, kept out of the public header.
struct PoseAssocNode::Impl {
  PoseAssocConfig config{};  // Tunables, filled in by Init() via ParseConfig().
  std::string init_err{};    // Reason text from the most recent failed ParseConfig().
};
|
|
|
|
PoseAssocNode::PoseAssocNode() : impl_(std::make_unique<Impl>()) {}
|
|
|
|
PoseAssocNode::~PoseAssocNode() = default;
|
|
|
|
std::string PoseAssocNode::Id() const {
|
|
return id_;
|
|
}
|
|
|
|
std::string PoseAssocNode::Type() const {
|
|
return "pose_assoc";
|
|
}
|
|
|
|
bool PoseAssocNode::Init(const SimpleJson& config, const NodeContext& ctx) {
|
|
id_ = config.ValueOr<std::string>("id", "pose_assoc");
|
|
if (!ParseConfig(config, impl_->config, impl_->init_err)) {
|
|
LogError("[pose_assoc] invalid config: " + impl_->init_err);
|
|
return false;
|
|
}
|
|
output_queues_ = ctx.output_queues;
|
|
return true;
|
|
}
|
|
|
|
// Nothing to set up: all work happens synchronously inside Process().
bool PoseAssocNode::Start() { return true; }

// Nothing to tear down; counterpart of Start().
void PoseAssocNode::Stop() {}
|
|
|
|
// Copies detection track ids onto pose results, then forwards the frame.
//
// A null frame is dropped (NodeStatus::DROP). Otherwise, when the frame
// carries both pose and detection results:
//   1. every pose's track_id is reset to -1;
//   2. candidate (pose, detection) pairs are collected from detections with a
//      non-negative track_id whose IoU with the pose bbox is >= min_iou;
//   3. candidates are accepted greedily in order of decreasing IoU (ties
//      broken by lower pose index, then lower detection index). Each pose and
//      each detection is used at most once, and an accepted pose inherits
//      its detection's track_id.
// Frames lacking either result pass through unmodified. Every non-null
// frame is pushed to all downstream queues and NodeStatus::OK is returned.
NodeStatus PoseAssocNode::Process(FramePtr frame) {
  if (!frame) return NodeStatus::DROP;

  const bool has_pose_and_det = frame->pose && frame->det;
  if (has_pose_and_det) {
    auto& poses = frame->pose->items;
    const auto& dets = frame->det->items;
    const float threshold = impl_->config.min_iou;

    // Start from a clean slate so stale ids never survive an unmatched pose.
    for (auto& pose_item : poses) {
      pose_item.track_id = -1;
    }

    // Gather every pair that clears the IoU threshold.
    std::vector<MatchCandidate> pairs;
    pairs.reserve(poses.size() * dets.size());
    for (size_t p = 0; p < poses.size(); ++p) {
      for (size_t d = 0; d < dets.size(); ++d) {
        if (dets[d].track_id < 0) continue;  // Untracked detections carry no id.
        const float overlap = IoU(poses[p].bbox, dets[d].bbox);
        if (overlap >= threshold) {
          pairs.push_back(MatchCandidate{overlap, static_cast<int>(p), static_cast<int>(d)});
        }
      }
    }

    // Best overlap first; index tie-breakers keep the assignment deterministic.
    std::sort(pairs.begin(), pairs.end(), [](const MatchCandidate& a, const MatchCandidate& b) {
      if (a.iou != b.iou) return a.iou > b.iou;
      if (a.pose_index != b.pose_index) return a.pose_index < b.pose_index;
      return a.det_index < b.det_index;
    });

    // Greedy one-to-one assignment.
    std::vector<bool> pose_taken(poses.size(), false);
    std::vector<bool> det_taken(dets.size(), false);
    for (const MatchCandidate& pair : pairs) {
      const auto pose_slot = static_cast<size_t>(pair.pose_index);
      const auto det_slot = static_cast<size_t>(pair.det_index);
      if (pose_taken[pose_slot] || det_taken[det_slot]) continue;
      poses[pose_slot].track_id = dets[det_slot].track_id;
      pose_taken[pose_slot] = true;
      det_taken[det_slot] = true;
    }
  }

  PushToDownstream(output_queues_, frame);
  return NodeStatus::OK;
}
|
|
|
|
// Registers PoseAssocNode under the type name "pose_assoc" via REGISTER_NODE
// (macro defined elsewhere). Registration is compiled out when
// RK3588_TEST_BUILD is defined.
// NOTE(review): presumably this lets tests construct the node directly
// without static registration side effects - confirm.
#ifndef RK3588_TEST_BUILD
REGISTER_NODE(PoseAssocNode, "pose_assoc");
#endif
|
|
|
|
} // namespace rk3588
|