OrangePi3588Media/plugins/pose_assoc/pose_assoc_node.cpp

140 lines
4.3 KiB
C++

#include "pose_assoc_node.h"
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "pose/pose_result.h"
#include "utils/logger.h"
namespace rk3588 {
namespace {
// Tunable settings for the pose/detection association step.
struct PoseAssocConfig {
    // Smallest IoU at which a pose box and a detection box may be paired.
    float min_iou{0.1f};
};
// One (pose, detection) pairing together with the IoU of their boxes.
// Member order is significant: instances are built positionally as
// {iou, pose_index, det_index}.
struct MatchCandidate {
    float iou{0.0f};     // overlap score of the two boxes
    int pose_index{-1};  // index into the pose item list, -1 when unset
    int det_index{-1};   // index into the detection item list, -1 when unset
};
// Reads the association settings from `config` into `out`.
//
// Recognised keys:
//   "min_iou" (float, default 0.1) - must lie in [0, 1].
//
// Returns true on success. On failure returns false and stores a
// human-readable reason in `err`; `err` is left untouched on success.
// `out.min_iou` is overwritten with the parsed value in both cases.
bool ParseConfig(const SimpleJson& config, PoseAssocConfig& out, std::string& err) {
    out.min_iou = config.ValueOr<float>("min_iou", 0.1f);
    // Expressed as a negated "in range" test so that NaN is rejected too:
    // every ordered comparison with NaN is false, so `x < 0 || x > 1` would
    // accept it, and a NaN threshold would make every later `iou >= min_iou`
    // check fail, silently disabling association.
    const bool in_range = out.min_iou >= 0.0f && out.min_iou <= 1.0f;
    if (!in_range) {
        err = "min_iou must be in [0, 1]";
        return false;
    }
    return true;
}
float IoU(const Rect& lhs, const Rect& rhs) {
const float x1 = std::max(lhs.x, rhs.x);
const float y1 = std::max(lhs.y, rhs.y);
const float x2 = std::min(lhs.x + lhs.w, rhs.x + rhs.w);
const float y2 = std::min(lhs.y + lhs.h, rhs.y + rhs.h);
const float iw = std::max(0.0f, x2 - x1);
const float ih = std::max(0.0f, y2 - y1);
const float inter = iw * ih;
const float area_l = std::max(0.0f, lhs.w) * std::max(0.0f, lhs.h);
const float area_r = std::max(0.0f, rhs.w) * std::max(0.0f, rhs.h);
const float uni = area_l + area_r - inter;
return uni <= 0.0f ? 0.0f : (inter / uni);
}
// Pushes `frame` onto every non-null queue in `output_queues`.
// Null entries are skipped; whatever Push returns is not inspected.
void PushToDownstream(const std::vector<std::shared_ptr<SpscQueue<FramePtr>>>& output_queues, const FramePtr& frame) {
    for (const auto& queue : output_queues) {
        if (!queue) {
            continue;
        }
        queue->Push(frame);
    }
}
} // namespace
// Private state of PoseAssocNode, defined in this translation unit.
struct PoseAssocNode::Impl {
    // Settings filled in by Init() via ParseConfig.
    PoseAssocConfig config{};
    // Receives the failure reason from ParseConfig when Init() rejects a config.
    std::string init_err{};
};
PoseAssocNode::PoseAssocNode() : impl_(std::make_unique<Impl>()) {}
// Defaulted here, after Impl has been fully defined in this file.
// NOTE(review): presumably required because impl_ is a smart pointer to an
// Impl that is only forward-declared in pose_assoc_node.h - confirm.
PoseAssocNode::~PoseAssocNode() = default;
// Returns a copy of id_, which Init() sets from the "id" config value
// (falling back to "pose_assoc" when the key is absent).
std::string PoseAssocNode::Id() const {
return id_;
}
// Returns the fixed type name of this node, "pose_assoc".
std::string PoseAssocNode::Type() const {
    static constexpr char kTypeName[] = "pose_assoc";
    return kTypeName;
}
// Configures the node from `config` and wires it to the queues in `ctx`.
//
// The node id is read from "id" (default "pose_assoc") before validation, so
// it is set even when Init fails. The remaining settings are validated by
// ParseConfig; on rejection the reason is logged and false is returned
// without touching output_queues_. On success the output queues are taken
// from `ctx` and true is returned.
bool PoseAssocNode::Init(const SimpleJson& config, const NodeContext& ctx) {
    id_ = config.ValueOr<std::string>("id", "pose_assoc");

    const bool config_ok = ParseConfig(config, impl_->config, impl_->init_err);
    if (config_ok) {
        output_queues_ = ctx.output_queues;
        return true;
    }

    LogError("[pose_assoc] invalid config: " + impl_->init_err);
    return false;
}
// Nothing to start for this node: it does no work here and always reports
// success.
bool PoseAssocNode::Start() {
return true;
}
// Intentionally empty: this file gives the node nothing to shut down.
void PoseAssocNode::Stop() {}
// Associates pose results with tracked detections on one frame, then forwards
// the frame to every output queue.
//
// When the frame carries both a pose result and a detection result:
//   1. every pose item's track_id is reset to -1;
//   2. every (pose, detection) pair whose detection has track_id >= 0 and
//      whose bounding-box IoU is >= the configured min_iou becomes a candidate;
//   3. candidates are consumed greedily in descending IoU order (ties go to
//      the lower pose index, then the lower detection index); each pose and
//      each detection is used at most once, and the pose receives the
//      detection's track_id.
// A frame missing either result is forwarded unchanged.
//
// Returns NodeStatus::DROP for a null frame (nothing is forwarded) and
// NodeStatus::OK otherwise.
NodeStatus PoseAssocNode::Process(FramePtr frame) {
    if (!frame) {
        return NodeStatus::DROP;
    }

    if (frame->pose && frame->det) {
        auto& poses = frame->pose->items;
        const auto& dets = frame->det->items;

        // Start from a clean slate: poses left unmatched below keep -1.
        for (auto& pose_item : poses) {
            pose_item.track_id = -1;
        }

        const size_t pose_count = poses.size();
        const size_t det_count = dets.size();
        const float threshold = impl_->config.min_iou;

        // Gather every pairing that clears the IoU threshold.
        std::vector<MatchCandidate> pairs;
        pairs.reserve(pose_count * det_count);
        for (size_t p = 0; p < pose_count; ++p) {
            const Rect& pose_box = poses[p].bbox;
            for (size_t d = 0; d < det_count; ++d) {
                const auto& det_item = dets[d];
                if (det_item.track_id < 0) {
                    continue;  // detection carries no track id to hand over
                }
                const float overlap = IoU(pose_box, det_item.bbox);
                // Positive ">=" form on purpose: a NaN overlap is never accepted.
                if (overlap >= threshold) {
                    pairs.push_back(MatchCandidate{overlap, static_cast<int>(p), static_cast<int>(d)});
                }
            }
        }

        // Best overlap first; index tie-breakers keep the order deterministic.
        const auto better_match = [](const MatchCandidate& a, const MatchCandidate& b) {
            if (a.iou != b.iou) {
                return a.iou > b.iou;
            }
            if (a.pose_index != b.pose_index) {
                return a.pose_index < b.pose_index;
            }
            return a.det_index < b.det_index;
        };
        std::sort(pairs.begin(), pairs.end(), better_match);

        // Greedy one-to-one assignment over the sorted candidates.
        std::vector<bool> pose_taken(pose_count, false);
        std::vector<bool> det_taken(det_count, false);
        for (const MatchCandidate& pair : pairs) {
            const auto p = static_cast<size_t>(pair.pose_index);
            const auto d = static_cast<size_t>(pair.det_index);
            if (pose_taken[p] || det_taken[d]) {
                continue;
            }
            poses[p].track_id = dets[d].track_id;
            pose_taken[p] = true;
            det_taken[d] = true;
        }
    }

    PushToDownstream(output_queues_, frame);
    return NodeStatus::OK;
}
// Registration is compiled out when RK3588_TEST_BUILD is defined.
// NOTE(review): REGISTER_NODE presumably adds PoseAssocNode to a node factory
// under the name "pose_assoc" - confirm against the macro's definition.
#ifndef RK3588_TEST_BUILD
REGISTER_NODE(PoseAssocNode, "pose_assoc");
#endif
} // namespace rk3588