// OrangePi3588Media/plugins/tracker/tracker_node.cpp
// (repository listing metadata: 721 lines, 25 KiB, C++)

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "node.h"
#include "utils/logger.h"
#include "utils/shared_state.h"
namespace rk3588 {
namespace {
using Clock = std::chrono::steady_clock;

// Returns the steady-clock reading (time since the clock's epoch) in
// microseconds. Used for elapsed-time measurements and as a fallback
// timestamp source.
static inline uint64_t NowUsSteady() {
    const auto since_epoch = Clock::now().time_since_epoch();
    const auto micros = std::chrono::duration_cast<std::chrono::microseconds>(since_epoch);
    return static_cast<uint64_t>(micros.count());
}
// Area of a rectangle; a negative width or height is treated as zero.
static inline float Area(const Rect& r) {
    return std::max(0.0f, r.w) * std::max(0.0f, r.h);
}
// Intersection-over-union of two rectangles given as (x, y, w, h).
// Returns 0 when the rectangles do not overlap or the union is not positive.
static inline float IoU(const Rect& a, const Rect& b) {
    // Corners of the overlap region.
    const float left = std::max<float>(a.x, b.x);
    const float top = std::max<float>(a.y, b.y);
    const float right = std::min<float>(a.x + a.w, b.x + b.w);
    const float bottom = std::min<float>(a.y + a.h, b.y + b.h);

    const float overlap_w = std::max(0.0f, right - left);
    const float overlap_h = std::max(0.0f, bottom - top);
    const float overlap = overlap_w * overlap_h;
    if (overlap <= 0.0f) return 0.0f;

    const float area_a = std::max(0.0f, Area(a));
    const float area_b = std::max(0.0f, Area(b));
    const float combined = area_a + area_b - overlap;
    if (combined <= 0.0f) return 0.0f;
    return overlap / combined;
}
// Parsed tracker configuration. Instances are built by BuildConfigSnapshot()
// and shared with the processing path through a shared_ptr.
struct ConfigSnapshot {
    std::string id;
    std::string mode;  // off | bytetrack_lite
    bool per_class{true};  // match tracks and detections only within the same class id
    // Internal graph-scoped key used to publish current tracked targets.
    std::string state_key;
    std::set<int> track_classes;   // whitelist (takes precedence when non-empty)
    std::set<int> ignore_classes;  // blacklist (consulted only when the whitelist is empty)
    std::unordered_set<std::string> allowed_models;  // empty = any model is tracked
    float high_th{0.5f};  // score >= high_th: first-stage detection, may start a track
    float low_th{0.1f};   // low_th <= score < high_th: second-stage detection
    float iou_th{0.3f};   // minimum IoU for a track/detection pair to match
    int64_t max_age_ms{1500};  // a track unseen for longer than this is removed
    int min_hits{2};           // hits required before a track is confirmed
    int max_tracks{128};       // upper bound on simultaneously kept tracks
    bool stats_log{false};     // periodic stats logging on/off
    int64_t stats_interval_ms{200};  // minimum spacing between stats log lines
};
// Fills `out` with the non-negative integers found in the JSON array `arr`.
// Elements that read back as negative (including the -1 fallback used for
// elements AsInt cannot convert) are skipped. On a non-array input, sets
// `err` and returns false without touching `out`.
static bool ParseIntSet(const SimpleJson& arr, std::set<int>& out, std::string& err) {
    if (arr.IsArray()) {
        out.clear();
        for (const auto& element : arr.AsArray()) {
            const int value = element.AsInt(-1);
            if (value >= 0) out.insert(value);
        }
        return true;
    }
    err = "expected array";
    return false;
}
// Fills `out` with the non-empty strings found in the JSON array `arr`.
// On a non-array input, sets `err` and returns false without touching `out`.
static bool ParseStrSet(const SimpleJson& arr, std::unordered_set<std::string>& out, std::string& err) {
    if (arr.IsArray()) {
        out.clear();
        for (const auto& element : arr.AsArray()) {
            const std::string text = element.AsString("");
            if (text.empty()) continue;
            out.insert(text);
        }
        return true;
    }
    err = "expected array";
    return false;
}
// Parses and validates the tracker configuration.
//
// On success stores a freshly built snapshot in `out` and returns true.
// On failure returns false with a description in `err`; `out` is left
// untouched. `state_key` is not filled in here.
//
// Out-of-range thresholds and an unknown mode are rejected, while
// max_age_ms / min_hits / max_tracks / stats_interval are clamped to their
// minimum instead.
static bool BuildConfigSnapshot(const SimpleJson& config, std::shared_ptr<const ConfigSnapshot>& out,
                                std::string& err) {
    const auto fail = [&err](const std::string& reason) {
        err = reason;
        return false;
    };
    const auto in_unit_range = [](float v) { return !(v < 0.0f || v > 1.0f); };

    auto built = std::make_shared<ConfigSnapshot>();
    ConfigSnapshot& c = *built;

    c.id = config.ValueOr<std::string>("id", "tracker");
    c.mode = config.ValueOr<std::string>("mode", "bytetrack_lite");
    c.per_class = config.ValueOr<bool>("per_class", true);

    // Optional class / model filters.
    const SimpleJson* node = config.Find("track_classes");
    if (node && !ParseIntSet(*node, c.track_classes, err)) {
        return fail("track_classes: " + err);
    }
    node = config.Find("ignore_classes");
    if (node && !ParseIntSet(*node, c.ignore_classes, err)) {
        return fail("ignore_classes: " + err);
    }
    node = config.Find("allowed_models");
    if (node && !ParseStrSet(*node, c.allowed_models, err)) {
        return fail("allowed_models: " + err);
    }

    c.high_th = config.ValueOr<float>("high_th", 0.5f);
    c.low_th = config.ValueOr<float>("low_th", 0.1f);
    c.iou_th = config.ValueOr<float>("iou_th", 0.3f);
    c.max_age_ms = static_cast<int64_t>(config.ValueOr<int>("max_age_ms", 1500));
    c.min_hits = config.ValueOr<int>("min_hits", 2);
    c.max_tracks = config.ValueOr<int>("max_tracks", 128);

    // Thresholds are validated strictly.
    if (!in_unit_range(c.high_th)) return fail("high_th must be in [0,1]");
    if (!in_unit_range(c.low_th)) return fail("low_th must be in [0,1]");
    if (c.low_th > c.high_th) return fail("low_th must be <= high_th");
    if (!in_unit_range(c.iou_th)) return fail("iou_th must be in [0,1]");

    // Counters and durations are clamped rather than rejected.
    c.max_age_ms = std::max<int64_t>(c.max_age_ms, 0);
    c.min_hits = std::max(c.min_hits, 1);
    c.max_tracks = std::max(c.max_tracks, 1);

    const SimpleJson* debug = config.Find("debug");
    if (debug) {
        if (!debug->IsObject()) return fail("debug must be object");
        c.stats_log = debug->ValueOr<bool>("stats", false);
        c.stats_interval_ms = static_cast<int64_t>(debug->ValueOr<int>("stats_interval", 200));
        c.stats_interval_ms = std::max<int64_t>(c.stats_interval_ms, 1);
    }

    const bool known_mode = (c.mode == "off") || (c.mode == "bytetrack_lite");
    if (!known_mode) return fail("mode must be 'off' or 'bytetrack_lite'");

    out = std::move(built);
    return true;
}
// State kept for one tracked object across frames.
struct Track {
    int id{-1};                      // track id handed out by the tracker
    int cls_id{-1};                  // class id of the detection that last defined it
    Rect bbox{};                     // box of the most recent matched detection
    uint64_t last_seen_us{0};        // tracker time of the last match (0 = never)
    uint64_t last_seen_frame_id{0};  // frame id of the last match
    int hit_streak{0};               // incremented on every match
    int total_hits{0};               // incremented on every match
    bool confirmed{false};           // set once hit_streak reaches min_hits
};
// Candidate association between one track and one detection.
// Member order matters: the struct is aggregate-initialised as {iou, track, det}.
struct MatchPair {
    float iou{0.0f};
    int track_idx{-1};
    int det_idx{-1};
};
static void GreedyMatch(const std::vector<Track>& tracks, const std::vector<int>& track_indices,
const std::vector<Detection>& dets, const std::vector<int>& det_indices, float iou_th,
std::vector<std::pair<int, int>>& out_matches, std::vector<int>& out_unmatched_tracks,
std::vector<int>& out_unmatched_dets) {
out_matches.clear();
out_unmatched_tracks = track_indices;
out_unmatched_dets = det_indices;
if (track_indices.empty() || det_indices.empty()) return;
std::vector<MatchPair> pairs;
pairs.reserve(track_indices.size() * det_indices.size());
for (int ti : track_indices) {
const auto& trk = tracks[static_cast<size_t>(ti)];
for (int di : det_indices) {
const float iou = IoU(trk.bbox, dets[static_cast<size_t>(di)].bbox);
if (iou >= iou_th) {
pairs.push_back(MatchPair{iou, ti, di});
}
}
}
if (pairs.empty()) return;
std::sort(pairs.begin(), pairs.end(), [](const MatchPair& a, const MatchPair& b) {
if (a.iou != b.iou) return a.iou > b.iou;
if (a.track_idx != b.track_idx) return a.track_idx < b.track_idx;
return a.det_idx < b.det_idx;
});
std::unordered_set<int> used_tracks;
std::unordered_set<int> used_dets;
used_tracks.reserve(track_indices.size());
used_dets.reserve(det_indices.size());
for (const auto& p : pairs) {
if (used_tracks.count(p.track_idx) || used_dets.count(p.det_idx)) continue;
used_tracks.insert(p.track_idx);
used_dets.insert(p.det_idx);
out_matches.emplace_back(p.track_idx, p.det_idx);
}
out_unmatched_tracks.clear();
out_unmatched_tracks.reserve(track_indices.size());
for (int ti : track_indices) {
if (!used_tracks.count(ti)) out_unmatched_tracks.push_back(ti);
}
out_unmatched_dets.clear();
out_unmatched_dets.reserve(det_indices.size());
for (int di : det_indices) {
if (!used_dets.count(di)) out_unmatched_dets.push_back(di);
}
}
} // namespace
class TrackerNode final : public INode {
public:
// Node id as taken from the configuration in Init().
std::string Id() const override {
    return id_;
}
// Node type name; matches the name used in REGISTER_NODE.
std::string Type() const override {
    return std::string("tracker");
}
bool Init(const SimpleJson& config, const NodeContext& ctx) override {
std::shared_ptr<const ConfigSnapshot> snap;
std::string err;
if (!BuildConfigSnapshot(config, snap, err)) {
LogError("[tracker] invalid config: " + err);
return false;
}
id_ = snap->id;
std::const_pointer_cast<ConfigSnapshot>(snap)->state_key = DefaultTrackedTargetsKey(ctx.graph_name);
input_queue_ = ctx.input_queue;
if (!input_queue_) {
LogError("[tracker] no input queue for node " + id_);
return false;
}
if (ctx.output_queues.empty()) {
LogError("[tracker] no output queue for node " + id_);
return false;
}
output_queues_ = ctx.output_queues;
{
std::lock_guard<std::mutex> lk(mu_);
cfg_ = std::move(snap);
tracks_.clear();
next_track_id_ = 0;
last_stats_us_ = 0;
}
return true;
}
// Logs the node id and the shared-state key it publishes under. Always succeeds.
bool Start() override {
    std::string shared_key;
    {
        std::lock_guard<std::mutex> guard(mu_);
        if (cfg_) shared_key = cfg_->state_key;
    }
    LogInfo("[tracker] started id=" + id_ + " shared_target_key=" + shared_key);
    return true;
}
// Only logs; no state is released here.
void Stop() override {
    LogInfo(std::string("[tracker] stopped id=") + id_);
}
bool UpdateConfig(const SimpleJson& new_config) override {
std::shared_ptr<const ConfigSnapshot> snap;
std::string err;
if (!BuildConfigSnapshot(new_config, snap, err)) {
LogWarn("[tracker] UpdateConfig rejected: " + err);
return false;
}
if (!id_.empty() && !snap->id.empty() && snap->id != id_) {
return false;
}
bool should_reset = false;
{
std::lock_guard<std::mutex> lk(mu_);
if (cfg_) {
std::const_pointer_cast<ConfigSnapshot>(snap)->state_key = cfg_->state_key;
should_reset = (cfg_->mode != snap->mode) || (cfg_->per_class != snap->per_class) ||
(cfg_->track_classes != snap->track_classes) ||
(cfg_->ignore_classes != snap->ignore_classes) ||
(cfg_->allowed_models != snap->allowed_models) || (cfg_->high_th != snap->high_th) ||
(cfg_->low_th != snap->low_th) || (cfg_->iou_th != snap->iou_th) ||
(cfg_->max_age_ms != snap->max_age_ms) || (cfg_->min_hits != snap->min_hits) ||
(cfg_->max_tracks != snap->max_tracks);
}
cfg_ = std::move(snap);
if (should_reset) {
tracks_.clear();
}
}
return true;
}
// Exports the tracker counters as a JSON object. Always returns true.
// `avg_process_time_ms` is total processing time divided by the number of
// processed frames, or 0 before the first counted frame.
bool GetCustomMetrics(SimpleJson& out) const override {
    const auto as_json = [](const std::atomic<uint64_t>& counter) {
        return SimpleJson(static_cast<double>(counter.load()));
    };

    SimpleJson::Object fields;
    fields["tracks_active"] = as_json(tracks_active_);
    fields["tracks_created_total"] = as_json(tracks_created_total_);
    fields["tracks_removed_total"] = as_json(tracks_removed_total_);
    fields["matched_total"] = as_json(matched_total_);
    fields["unmatched_dets_total"] = as_json(unmatched_dets_total_);

    const uint64_t frame_count = processed_frames_.load();
    const uint64_t busy_us = total_process_us_.load();
    double avg_ms = 0.0;
    if (frame_count > 0) {
        avg_ms = static_cast<double>(busy_us) / 1000.0 / static_cast<double>(frame_count);
    }
    fields["avg_process_time_ms"] = SimpleJson(avg_ms);

    out = SimpleJson(std::move(fields));
    return true;
}
// Runs one tracking step on `frame` and forwards it to all output queues.
//
// Returns DROP for a null frame and OK otherwise. With no config or
// mode == "off" the frame is forwarded untouched and not counted in the
// timing metrics. Otherwise:
//   1. expired tracks are pruned against the resolved frame time;
//   2. frames without detections, or from a model outside `allowed_models`,
//      are forwarded without touching their track ids;
//   3. every detection's track_id is reset to -1 and trackable detections are
//      split into a high-score and a low-score list;
//   4. stage 1 matches tracks against high-score detections, stage 2 matches
//      the still-unmatched tracks against low-score detections;
//   5. unmatched high-score detections start new tracks (up to max_tracks).
// A detection receives a track id only once its track is confirmed.
NodeStatus Process(FramePtr frame) override {
    if (!frame) return NodeStatus::DROP;
    const uint64_t start_us = NowUsSteady();

    std::shared_ptr<const ConfigSnapshot> cfg;
    {
        std::lock_guard<std::mutex> guard(mu_);
        cfg = cfg_;
    }
    if (!cfg || cfg->mode == "off") {
        PushToDownstream(frame);
        return NodeStatus::OK;
    }

    uint64_t now_us = 0;
    uint64_t pruned = 0;
    {
        std::lock_guard<std::mutex> guard(mu_);
        now_us = ResolveNowUsLocked(*frame);
        pruned = PruneExpiredLocked(now_us, *cfg);
    }
    if (pruned) tracks_removed_total_.fetch_add(pruned);

    // Common tail of every counted path: stats, shared state, forward, timing.
    const auto finish = [&]() -> NodeStatus {
        MaybeLogStats(now_us, *cfg);
        MaybeUpdateSharedState(*cfg, frame->det.get());
        PushToDownstream(frame);
        total_process_us_.fetch_add(NowUsSteady() - start_us);
        processed_frames_.fetch_add(1);
        return NodeStatus::OK;
    };

    if (!frame->det || frame->det->items.empty()) return finish();

    // Not an enabled model: pass-through (do not rewrite track_id).
    if (!cfg->allowed_models.empty() &&
        cfg->allowed_models.find(frame->det->model_name) == cfg->allowed_models.end()) {
        return finish();
    }

    auto& dets = frame->det->items;

    // Split trackable detections by score band; everything starts untracked.
    std::vector<int> high_dets;
    std::vector<int> low_dets;
    high_dets.reserve(dets.size());
    low_dets.reserve(dets.size());
    for (size_t i = 0; i < dets.size(); ++i) {
        auto& d = dets[i];
        d.track_id = -1;
        if (!IsTrackClass(*cfg, d.cls_id)) continue;
        const int idx = static_cast<int>(i);
        if (d.score >= cfg->high_th) {
            high_dets.push_back(idx);
        } else if (d.score >= cfg->low_th) {
            low_dets.push_back(idx);
        }
    }
    if (high_dets.empty() && low_dets.empty()) return finish();

    uint64_t matched_local = 0;
    uint64_t unmatched_dets_local = 0;
    uint64_t created_local = 0;
    {
        std::lock_guard<std::mutex> guard(mu_);

        // With per_class off, everything lands in a single group (key 0).
        const auto group_of = [&cfg](int cls_id) -> int { return cfg->per_class ? cls_id : 0; };

        std::map<int, std::vector<int>> tracks_by_group;
        for (size_t i = 0; i < tracks_.size(); ++i) {
            tracks_by_group[group_of(tracks_[i].cls_id)].push_back(static_cast<int>(i));
        }
        std::map<int, std::vector<int>> high_by_group;
        std::map<int, std::vector<int>> low_by_group;
        for (const int di : high_dets) {
            high_by_group[group_of(dets[static_cast<size_t>(di)].cls_id)].push_back(di);
        }
        for (const int di : low_dets) {
            low_by_group[group_of(dets[static_cast<size_t>(di)].cls_id)].push_back(di);
        }

        // Refreshes each matched track and stamps confirmed ids onto detections.
        const auto apply_matches = [&](const std::vector<std::pair<int, int>>& accepted) {
            for (const auto& m : accepted) {
                Track& trk = tracks_[static_cast<size_t>(m.first)];
                Detection& det = dets[static_cast<size_t>(m.second)];
                UpdateTrackLocked(trk, det, now_us, frame->frame_id, *cfg);
                if (trk.confirmed) det.track_id = trk.id;
                ++matched_local;
            }
        };

        std::vector<std::pair<int, int>> matches;
        std::vector<int> un_tracks;
        std::vector<int> un_dets;

        // Stage 1: tracks vs. high-score detections, group by group.
        // Tracks left unmatched are collected per group for stage 2.
        std::map<int, std::vector<int>> leftover_tracks_by_group;
        const std::vector<int> no_dets;
        for (const auto& entry : tracks_by_group) {
            const int group = entry.first;
            const auto found = high_by_group.find(group);
            const std::vector<int>& group_dets = (found != high_by_group.end()) ? found->second : no_dets;
            GreedyMatch(tracks_, entry.second, dets, group_dets, cfg->iou_th, matches, un_tracks, un_dets);
            apply_matches(matches);
            if (!un_tracks.empty()) leftover_tracks_by_group[group] = un_tracks;
            // Whatever high-score detections remain may start new tracks below.
            high_by_group[group] = un_dets;
        }

        // Stage 2: still-unmatched tracks vs. low-score detections.
        for (const auto& entry : leftover_tracks_by_group) {
            const auto found = low_by_group.find(entry.first);
            if (found == low_by_group.end() || found->second.empty()) continue;
            GreedyMatch(tracks_, entry.second, dets, found->second, cfg->iou_th, matches, un_tracks, un_dets);
            apply_matches(matches);
            found->second = un_dets;
        }

        // Remaining HIGH-score detections start new tracks until capacity is reached.
        for (const auto& entry : high_by_group) {
            for (const int di : entry.second) {
                if (static_cast<int>(tracks_.size()) >= cfg->max_tracks) {
                    ++unmatched_dets_local;
                    continue;
                }
                Detection& det = dets[static_cast<size_t>(di)];
                Track fresh;
                fresh.id = next_track_id_++;
                fresh.cls_id = det.cls_id;
                fresh.bbox = det.bbox;
                fresh.last_seen_us = now_us;
                fresh.last_seen_frame_id = frame->frame_id;
                fresh.hit_streak = 1;
                fresh.total_hits = 1;
                fresh.confirmed = (cfg->min_hits <= 1);
                if (fresh.confirmed) det.track_id = fresh.id;
                tracks_.push_back(std::move(fresh));
                ++created_local;
            }
        }

        // Low-score detections never start tracks; leftovers are only counted.
        for (const auto& entry : low_by_group) {
            unmatched_dets_local += entry.second.size();
        }

        // Tracks not updated this frame stay as they are; time-based pruning removes them.
        tracks_active_.store(tracks_.size());
    }

    if (created_local) tracks_created_total_.fetch_add(created_local);
    if (matched_local) matched_total_.fetch_add(matched_local);
    if (unmatched_dets_local) unmatched_dets_total_.fetch_add(unmatched_dets_local);
    return finish();
}
private:
void MaybeUpdateSharedState(const ConfigSnapshot& cfg, const DetectionResult* det_hint) {
if (cfg.state_key.empty()) return;
TargetsSnapshot snap;
snap.update_steady_us = NowSteadyUs();
if (det_hint) {
snap.img_w = det_hint->img_w;
snap.img_h = det_hint->img_h;
snap.model_name = det_hint->model_name;
}
{
std::lock_guard<std::mutex> lk(mu_);
snap.objects.reserve(tracks_.size());
for (const auto& t : tracks_) {
TrackedObject obj;
obj.cls_id = t.cls_id;
obj.track_id = t.id;
obj.score = 0.0f;
obj.bbox = t.bbox;
obj.confirmed = t.confirmed;
snap.objects.push_back(std::move(obj));
}
}
SharedState::Instance().SetTargets(cfg.state_key, std::move(snap));
}
// Derives the tracker time for `frame` and records it in last_time_us_.
// Callers in this file hold mu_ while calling.
//
// - pts == 0: use the steady clock on the first call, afterwards previous + 1.
// - pts more than 2,000,000 behind the previous time: previous + 1.
// - pts slightly behind the previous time: hold at the previous time.
// - otherwise: pts as is.
// NOTE(review): the 2,000,000 constant and the *_us naming suggest pts is in
// microseconds - confirm against the frame producer.
uint64_t ResolveNowUsLocked(const Frame& frame) {
    const uint64_t prev = last_time_us_;
    uint64_t resolved = frame.pts;
    if (resolved == 0) {
        resolved = (prev == 0) ? NowUsSteady() : (prev + 1);
    } else if (prev != 0) {
        // Clamp non-monotonic timestamps to avoid sudden large backward jumps.
        if (resolved + 2000000ULL < prev) {
            resolved = prev + 1;
        } else if (resolved < prev) {
            resolved = prev;
        }
    }
    last_time_us_ = resolved;
    return resolved;
}
// Decides whether detections of class `cls_id` take part in tracking.
// Negative ids never do. A non-empty whitelist decides alone; otherwise the
// class is tracked unless it is blacklisted.
static bool IsTrackClass(const ConfigSnapshot& cfg, int cls_id) {
    if (cls_id < 0) return false;
    const bool has_whitelist = !cfg.track_classes.empty();
    if (has_whitelist) {
        return cfg.track_classes.find(cls_id) != cfg.track_classes.end();
    }
    return cfg.ignore_classes.find(cls_id) == cfg.ignore_classes.end();
}
// Refreshes `trk` from a matched detection: box, last-seen time and frame id,
// hit counters, and confirmation once hit_streak reaches cfg.min_hits.
// When matching is not per-class the track also adopts the detection's class.
static void UpdateTrackLocked(Track& trk, const Detection& det, uint64_t now_us, uint64_t frame_id,
                              const ConfigSnapshot& cfg) {
    trk.bbox = det.bbox;
    trk.last_seen_us = now_us;
    trk.last_seen_frame_id = frame_id;
    ++trk.hit_streak;
    ++trk.total_hits;
    trk.confirmed = trk.confirmed || (trk.hit_streak >= cfg.min_hits);
    if (!cfg.per_class) trk.cls_id = det.cls_id;
}
uint64_t PruneExpiredLocked(uint64_t now_us, const ConfigSnapshot& cfg) {
if (tracks_.empty()) return 0;
const uint64_t max_age_us = static_cast<uint64_t>(std::max<int64_t>(0, cfg.max_age_ms)) * 1000ULL;
const size_t before = tracks_.size();
if (max_age_us == 0) {
tracks_.clear();
tracks_active_.store(0);
return static_cast<uint64_t>(before);
}
tracks_.erase(std::remove_if(tracks_.begin(), tracks_.end(), [&](const Track& t) {
if (t.last_seen_us == 0) return true;
return (now_us > t.last_seen_us) && ((now_us - t.last_seen_us) > max_age_us);
}),
tracks_.end());
const size_t after = tracks_.size();
tracks_active_.store(after);
return static_cast<uint64_t>(before - after);
}
// Emits one stats line when stats logging is enabled and either nothing has
// been logged yet or at least cfg.stats_interval_ms of tracker time has
// passed since the previous line.
void MaybeLogStats(uint64_t now_us, const ConfigSnapshot& cfg) {
    if (!cfg.stats_log) return;

    const uint64_t interval_us = static_cast<uint64_t>(cfg.stats_interval_ms) * 1000ULL;
    {
        std::lock_guard<std::mutex> guard(mu_);
        const bool never_logged = (last_stats_us_ == 0);
        const bool due =
            never_logged || (now_us > last_stats_us_ && (now_us - last_stats_us_) >= interval_us);
        if (!due) return;
        last_stats_us_ = now_us;
    }

    // Formatting and logging happen outside the lock.
    std::string line = "[tracker] id=" + id_;
    line += " tracks=" + std::to_string(tracks_active_.load());
    line += " created=" + std::to_string(tracks_created_total_.load());
    line += " removed=" + std::to_string(tracks_removed_total_.load());
    line += " matched=" + std::to_string(matched_total_.load());
    line += " unmatch_det=" + std::to_string(unmatched_dets_total_.load());
    LogInfo(std::move(line));
}
// Pushes the same frame pointer to every output queue, in order.
void PushToDownstream(FramePtr frame) {
    for (size_t i = 0; i < output_queues_.size(); ++i) {
        output_queues_[i]->Push(frame);
    }
}
// Node id from the configuration; written in Init().
std::string id_;
// Queues handed over by the NodeContext in Init().
std::shared_ptr<SpscQueue<FramePtr>> input_queue_;
std::vector<std::shared_ptr<SpscQueue<FramePtr>>> output_queues_;
// mu_ is held for every access to cfg_, tracks_, next_track_id_,
// last_time_us_ and last_stats_us_ in this file.
mutable std::mutex mu_;
// Current configuration; replaced as a whole by Init() / UpdateConfig().
std::shared_ptr<const ConfigSnapshot> cfg_;
// Live tracks (confirmed and unconfirmed).
std::vector<Track> tracks_;
// Id assigned to the next created track; reset to 0 in Init().
int next_track_id_ = 0;
// Last value returned by ResolveNowUsLocked(); 0 until the first frame.
uint64_t last_time_us_ = 0;
// Tracker time of the last stats log line; 0 = never logged.
uint64_t last_stats_us_ = 0;
// Metrics, readable without mu_ (see GetCustomMetrics()).
// Gauge: number of tracks after the most recent prune/update.
std::atomic<uint64_t> tracks_active_{0};
// Monotonic totals.
std::atomic<uint64_t> tracks_created_total_{0};
std::atomic<uint64_t> tracks_removed_total_{0};
std::atomic<uint64_t> matched_total_{0};
std::atomic<uint64_t> unmatched_dets_total_{0};
// Frames counted in the timing average and their summed processing time in microseconds.
std::atomic<uint64_t> processed_frames_{0};
std::atomic<uint64_t> total_process_us_{0};
};
// Associates TrackerNode with the type name "tracker". REGISTER_NODE is
// defined outside this file; presumably it registers a factory so graphs can
// instantiate the node by type name - confirm in node.h.
REGISTER_NODE(TrackerNode, "tracker");
} // namespace rk3588