From 54a28ad5c062a1fe16caa635ccb64c1ffed914fe Mon Sep 17 00:00:00 2001 From: tian <11429339@qq.com> Date: Thu, 16 Apr 2026 10:41:26 +0800 Subject: [PATCH] Aggregate gallery search by person --- include/face/face_gallery_search.h | 87 ++++++++++++++++++++ plugins/ai_face_recog/ai_face_recog_node.cpp | 49 ++--------- tests/CMakeLists.txt | 1 + tests/test_face_gallery_search.cpp | 41 +++++++++ 4 files changed, 138 insertions(+), 40 deletions(-) create mode 100644 include/face/face_gallery_search.h create mode 100644 tests/test_face_gallery_search.cpp diff --git a/include/face/face_gallery_search.h b/include/face/face_gallery_search.h new file mode 100644 index 0000000..17b0816 --- /dev/null +++ b/include/face/face_gallery_search.h @@ -0,0 +1,87 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace rk3588 { + +struct GallerySearchEntry { + int person_id = -1; + std::string name; + std::vector emb; +}; + +struct GallerySearchResult { + int best_person_id = -1; + std::string best_name; + float best_sim = 0.0f; + float second_sim = 0.0f; +}; + +template +GallerySearchResult SearchTop2ByPersonGeneric(const std::vector& entries, + const std::vector& emb_normed, + PersonIdFn person_id_fn, + NameFn name_fn, + EmbFn emb_fn) { + GallerySearchResult r; + if (entries.empty() || emb_normed.empty()) return r; + + std::unordered_map> per_person_best; + per_person_best.reserve(entries.size()); + + for (const auto& entry : entries) { + const auto& entry_emb = emb_fn(entry); + if (entry_emb.size() != emb_normed.size()) continue; + + float sim = 0.0f; + for (size_t i = 0; i < emb_normed.size(); ++i) { + sim += emb_normed[i] * entry_emb[i]; + } + + const int person_id = person_id_fn(entry); + auto it = per_person_best.find(person_id); + if (it == per_person_best.end() || sim > it->second.second) { + per_person_best[person_id] = std::make_pair(name_fn(entry), sim); + } + } + + float best = -std::numeric_limits::infinity(); + float second = -std::numeric_limits::infinity(); + int best_person_id = -1; + std::string best_name; + + for (const auto& kv : per_person_best) { + const float sim = kv.second.second; + if (sim > best) { + second = best; + best = sim; + best_person_id = kv.first; + best_name = kv.second.first; + } else if (sim > second) { + second = sim; + } + } + + if (best_person_id >= 0) { + r.best_person_id = best_person_id; + r.best_name = best_name; + r.best_sim = best; + r.second_sim = std::isfinite(second) ? second : 0.0f; + } + return r; +} + +inline GallerySearchResult SearchTop2ByPerson(const std::vector& entries, + const std::vector& emb_normed) { + return SearchTop2ByPersonGeneric( + entries, emb_normed, + [](const GallerySearchEntry& entry) { return entry.person_id; }, + [](const GallerySearchEntry& entry) -> const std::string& { return entry.name; }, + [](const GallerySearchEntry& entry) -> const std::vector& { return entry.emb; }); +} + +} // namespace rk3588 diff --git a/plugins/ai_face_recog/ai_face_recog_node.cpp b/plugins/ai_face_recog/ai_face_recog_node.cpp index 8539313..2290e92 100644 --- a/plugins/ai_face_recog/ai_face_recog_node.cpp +++ b/plugins/ai_face_recog/ai_face_recog_node.cpp @@ -13,6 +13,7 @@ #include #include "hw/i_infer_backend.h" +#include "face/face_gallery_search.h" #include "face/face_recog_debug.h" #include "face/face_result.h" #include "node.h" @@ -391,49 +392,17 @@ public: int Dim() const { return dim_; } size_t Size() const { return entries_.size(); } - struct SearchResult { - int best_person_id = -1; - std::string best_name; - float best_sim = 0.0f; - float second_sim = 0.0f; - }; - - SearchResult SearchTop2(const std::vector& emb_normed) const { - SearchResult r; - if (entries_.empty() || dim_ <= 0) return r; - if (static_cast(emb_normed.size()) != dim_) return r; - - float best = -std::numeric_limits::infinity(); - float second = -std::numeric_limits::infinity(); - int best_idx = -1; - - for (size_t i = 0; i < entries_.size(); ++i) { - const float sim = Dot(emb_normed, entries_[i].emb); - if (sim > best) { - second = best; - best = sim; - best_idx = static_cast(i); - } else if (sim > second) { - second = sim; - } - } - - if (best_idx >= 0) { - r.best_person_id = entries_[static_cast(best_idx)].person_id; - r.best_name = entries_[static_cast(best_idx)].name; - r.best_sim = best; - r.second_sim = std::isfinite(second) ? second : 0.0f; - } - return r; + GallerySearchResult SearchTop2(const std::vector& emb_normed) const { + if (entries_.empty() || dim_ <= 0) return {}; + if (static_cast(emb_normed.size()) != dim_) return {}; + return SearchTop2ByPersonGeneric( + entries_, emb_normed, + [](const GalleryEntry& entry) { return entry.person_id; }, + [](const GalleryEntry& entry) -> const std::string& { return entry.name; }, + [](const GalleryEntry& entry) -> const std::vector& { return entry.emb; }); } private: - static float Dot(const std::vector& a, const std::vector& b) { - float s = 0.0f; - for (size_t i = 0; i < a.size(); ++i) s += a[i] * b[i]; - return s; - } - static void L2Normalize(std::vector& v) { double ss = 0.0; for (float x : v) ss += static_cast(x) * static_cast(x); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7820710..0379567 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -48,6 +48,7 @@ add_executable(rk3588_gtests test_alarm_behavior_events.cpp test_log_action.cpp test_face_recog_debug.cpp + test_face_gallery_search.cpp test_face_track_association.cpp test_face_track_alarm.cpp test_external_api_action.cpp diff --git a/tests/test_face_gallery_search.cpp b/tests/test_face_gallery_search.cpp new file mode 100644 index 0000000..2ab7bef --- /dev/null +++ b/tests/test_face_gallery_search.cpp @@ -0,0 +1,41 @@ +#include + +#include + +#include "face/face_gallery_search.h" + +namespace rk3588 { +namespace { + +TEST(FaceGallerySearchTest, AggregatesMultipleEmbeddingsOfSamePersonIntoSingleTopPerson) { + const std::vector entries = { + {1, "alice", {1.0f, 0.0f}}, + {1, "alice", {0.95f, 0.05f}}, + {2, "bob", {0.70f, 0.70f}}, + {3, "carol", {0.0f, 1.0f}}, + }; + + const GallerySearchResult result = SearchTop2ByPerson(entries, {1.0f, 0.0f}); + + EXPECT_EQ(result.best_person_id, 1); + EXPECT_EQ(result.best_name, "alice"); + EXPECT_FLOAT_EQ(result.best_sim, 1.0f); + EXPECT_NEAR(result.second_sim, 0.70f, 1e-5f); +} + +TEST(FaceGallerySearchTest, UsesDistinctSecondPersonInsteadOfSameIdentityDuplicate) { + const std::vector entries = { + {1, "alice", {0.90f, 0.10f}}, + {1, "alice", {0.89f, 0.11f}}, + {2, "bob", {0.60f, 0.40f}}, + }; + + const GallerySearchResult result = SearchTop2ByPerson(entries, {0.90f, 0.10f}); + + EXPECT_EQ(result.best_person_id, 1); + EXPECT_FLOAT_EQ(result.best_sim, 0.82f); + EXPECT_FLOAT_EQ(result.second_sim, 0.58f); +} + +} // namespace +} // namespace rk3588