diff --git a/configs/full_pipeline_1080p_test_alarm.json b/configs/full_pipeline_1080p_test_alarm.json index f41013d..5dcf400 100644 --- a/configs/full_pipeline_1080p_test_alarm.json +++ b/configs/full_pipeline_1080p_test_alarm.json @@ -237,8 +237,8 @@ "max_shoe_roi_height_ratio": 0.6, "max_shoe_roi_area_ratio": 0.25, "min_front_shoe_width_ratio": 0.18, - "max_front_shoe_aspect_ratio": 1.5, - "max_side_height_width_ratio": 1.1, + "max_front_shoe_aspect_ratio": 2.0, + "max_side_height_width_ratio": 1.2, "max_shoe_aspect_ratio": 2.0 } }, diff --git a/plugins/logic_gate/logic_gate_node.cpp b/plugins/logic_gate/logic_gate_node.cpp index 2d028cb..521f78a 100644 --- a/plugins/logic_gate/logic_gate_node.cpp +++ b/plugins/logic_gate/logic_gate_node.cpp @@ -6,6 +6,7 @@ #include #include "color_analyzer.h" +#include "person_shoe_filter.h" #include "person_shoe_shape.h" #include "frame/frame.h" #include "node.h" @@ -335,6 +336,8 @@ private: std::vector shoe_used(shoe_indices.size(), false); std::vector appended; + std::vector selected_shoe_item_indices; + selected_shoe_item_indices.reserve(shoe_indices.size()); for (size_t pi = 0; pi < person_indices.size(); ++pi) { Detection& person = items[person_indices[pi]]; @@ -445,7 +448,9 @@ private: if (best_shoe_local >= 0) { shoe_used[static_cast(best_shoe_local)] = true; - Detection& matched_shoe = items[shoe_indices[static_cast(best_shoe_local)]]; + const size_t matched_item_index = shoe_indices[static_cast(best_shoe_local)]; + Detection& matched_shoe = items[matched_item_index]; + selected_shoe_item_indices.push_back(matched_item_index); if (config_.debug) { LogInfo("[logic_gate] shoe selected person_track=" + std::to_string(person.track_id) + " candidate_count=" + std::to_string(candidate_count) + @@ -493,6 +498,10 @@ private: } } + items = FilterPersonShoeOutput(items, + config_.person_shoe.shoe_class, + selected_shoe_item_indices); + if (!appended.empty()) { items.insert(items.end(), appended.begin(), appended.end()); } diff --git a/plugins/logic_gate/person_shoe_filter.h b/plugins/logic_gate/person_shoe_filter.h new file mode 100644 index 0000000..8e7cd3a --- /dev/null +++ b/plugins/logic_gate/person_shoe_filter.h @@ -0,0 +1,27 @@ +#pragma once + +#include +#include +#include + +#include "frame/frame.h" + +namespace rk3588 { + +inline std::vector FilterPersonShoeOutput(const std::vector& items, + int shoe_class, + const std::vector& selected_shoe_indices) { + const std::unordered_set selected(selected_shoe_indices.begin(), selected_shoe_indices.end()); + + std::vector filtered; + filtered.reserve(items.size()); + for (size_t i = 0; i < items.size(); ++i) { + const Detection& det = items[i]; + if (det.cls_id != shoe_class || selected.count(i) > 0) { + filtered.push_back(det); + } + } + return filtered; +} + +} // namespace rk3588 diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 94dc850..6ba9bc3 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -39,6 +39,7 @@ add_executable(rk3588_gtests test_frame_buffer.cpp test_behavior_event_model.cpp test_person_shoe_shape.cpp + test_person_shoe_filter.cpp test_region_event.cpp test_action_recog.cpp test_event_fusion.cpp diff --git a/tests/test_person_shoe_filter.cpp b/tests/test_person_shoe_filter.cpp new file mode 100644 index 0000000..322e5ee --- /dev/null +++ b/tests/test_person_shoe_filter.cpp @@ -0,0 +1,52 @@ +#include + +#include + +#include "../plugins/logic_gate/person_shoe_filter.h" + +namespace rk3588 { +namespace { + +Detection MakeDetection(int cls_id, int track_id, float x) { + Detection det; + det.cls_id = cls_id; + det.track_id = track_id; + det.score = 0.9f; + det.bbox = Rect{x, 10.0f, 20.0f, 20.0f}; + return det; +} + +TEST(PersonShoeFilterTest, KeepsOnlySelectedShoesAndAllNonShoes) { + std::vector items; + items.push_back(MakeDetection(0, 101, 10.0f)); + items.push_back(MakeDetection(1, -1, 20.0f)); + items.push_back(MakeDetection(1, 101, 30.0f)); + items.push_back(MakeDetection(2, 101, 40.0f)); + items.push_back(MakeDetection(1, -1, 50.0f)); + + const std::vector selected_shoe_indices = {2}; + const auto filtered = FilterPersonShoeOutput(items, 1, selected_shoe_indices); + + ASSERT_EQ(filtered.size(), 3u); + EXPECT_EQ(filtered[0].cls_id, 0); + EXPECT_EQ(filtered[1].cls_id, 1); + EXPECT_EQ(filtered[1].track_id, 101); + EXPECT_FLOAT_EQ(filtered[1].bbox.x, 30.0f); + EXPECT_EQ(filtered[2].cls_id, 2); +} + +TEST(PersonShoeFilterTest, DropsAllShoesWhenNothingIsSelected) { + std::vector items; + items.push_back(MakeDetection(1, -1, 20.0f)); + items.push_back(MakeDetection(0, 101, 30.0f)); + items.push_back(MakeDetection(1, -1, 40.0f)); + + const auto filtered = FilterPersonShoeOutput(items, 1, {}); + + ASSERT_EQ(filtered.size(), 1u); + EXPECT_EQ(filtered[0].cls_id, 0); + EXPECT_EQ(filtered[0].track_id, 101); +} + +} // namespace +} // namespace rk3588