// Source listing metadata: C++, 193 lines, 6.3 KiB.
#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstring>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "node.h"
#include "frame/frame.h"
#include "utils/logger.h"
#include "utils/simple_json.h"

#include "spatial_matcher.h"
#include "color_analyzer.h"
|
|
|
|
namespace rk3588 {
|
|
|
|
// 违规类型
|
|
/// Kinds of PPE violation this file distinguishes.
enum class ViolationType {
    NONE,               ///< No violation.
    MISSING_BOOTS,      ///< No boots present.
    WRONG_COLOR_BOOTS,  ///< Boots present, but their color is not acceptable.
};
|
|
|
|
// 违规信息
|
|
struct ViolationInfo {
|
|
ViolationType type = ViolationType::NONE;
|
|
std::string description;
|
|
int track_id = -1;
|
|
float confidence = 0.0f;
|
|
Rect violation_region;
|
|
};
|
|
|
|
// Logic Gate 配置
|
|
struct LogicGateConfig {
|
|
std::string mode = "ppe_boots_check";
|
|
int anchor_class = 6; // Person
|
|
int boots_class = 3; // boots
|
|
|
|
ColorConfig color;
|
|
bool enable_color_check = true;
|
|
|
|
std::string violation_key = "ppe_violation";
|
|
bool pass_through = true;
|
|
bool debug = false;
|
|
};
|
|
|
|
class LogicGateNode : public INode {
|
|
public:
|
|
LogicGateNode() = default;
|
|
~LogicGateNode() override = default;
|
|
|
|
std::string Id() const override { return id_; }
|
|
std::string Type() const override { return "logic_gate"; }
|
|
|
|
bool Init(const SimpleJson& config, const NodeContext& ctx) override {
|
|
id_ = config.ValueOr<std::string>("id", "logic_gate");
|
|
config_ = ParseConfig(config);
|
|
output_queues_ = ctx.output_queues;
|
|
|
|
// 初始化颜色分析器
|
|
if (config_.enable_color_check) {
|
|
color_analyzer_ = std::make_unique<ColorAnalyzer>(config_.color);
|
|
}
|
|
|
|
LogInfo("[LogicGateNode] Initialized, id=" + id_ + " mode=" + config_.mode +
|
|
" output_queues=" + std::to_string(output_queues_.size()));
|
|
return true;
|
|
}
|
|
|
|
bool Start() override {
|
|
LogInfo("[LogicGateNode] started id=" + id_);
|
|
return true;
|
|
}
|
|
|
|
void Stop() override {
|
|
}
|
|
|
|
NodeStatus Process(FramePtr frame) override {
|
|
if (!frame || !frame->det) {
|
|
PushToDownstream(frame);
|
|
return NodeStatus::OK;
|
|
}
|
|
|
|
if (config_.mode == "ppe_boots_check") {
|
|
ProcessPpeBootsCheck(frame);
|
|
}
|
|
|
|
PushToDownstream(frame);
|
|
return NodeStatus::OK;
|
|
}
|
|
|
|
private:
|
|
std::string id_;
|
|
LogicGateConfig config_;
|
|
std::unique_ptr<ColorAnalyzer> color_analyzer_;
|
|
std::vector<std::shared_ptr<SpscQueue<FramePtr>>> output_queues_;
|
|
|
|
LogicGateConfig ParseConfig(const SimpleJson& cfg) {
|
|
LogicGateConfig config;
|
|
|
|
config.mode = cfg.ValueOr<std::string>("mode", "ppe_boots_check");
|
|
config.anchor_class = cfg.ValueOr<int>("anchor_class", 6);
|
|
config.boots_class = cfg.ValueOr<int>("boots_class", 3);
|
|
config.violation_key = cfg.ValueOr<std::string>("violation_key", "ppe_violation");
|
|
config.pass_through = cfg.ValueOr<bool>("pass_through", true);
|
|
config.debug = cfg.ValueOr<bool>("debug", false);
|
|
config.enable_color_check = cfg.ValueOr<bool>("enable_color_check", true);
|
|
|
|
// 解析颜色配置
|
|
if (const SimpleJson* color = cfg.Find("color_check")) {
|
|
std::string method = color->ValueOr<std::string>("method", "hsv");
|
|
if (method == "hsv") config.color.method = ColorMethod::HSV;
|
|
else if (method == "rgb") config.color.method = ColorMethod::RGB;
|
|
else if (method == "brightness") config.color.method = ColorMethod::BRIGHTNESS;
|
|
|
|
config.color.dark_threshold = color->ValueOr<int>("dark_threshold", 80);
|
|
config.color.roi_expand = color->ValueOr<float>("roi_expand", 1.0f);
|
|
config.color.debug_output = config.debug;
|
|
}
|
|
|
|
return config;
|
|
}
|
|
|
|
void PushToDownstream(FramePtr frame) {
|
|
for (auto& q : output_queues_) {
|
|
if (q) q->Push(frame);
|
|
}
|
|
}
|
|
|
|
void ProcessPpeBootsCheck(FramePtr frame) {
|
|
const auto& detections = frame->det->items;
|
|
|
|
// 收集所有人和鞋
|
|
std::vector<Detection> persons;
|
|
std::vector<Detection> boots;
|
|
|
|
for (const auto& det : detections) {
|
|
if (det.cls_id == config_.anchor_class) {
|
|
persons.push_back(det);
|
|
} else if (det.cls_id == config_.boots_class) {
|
|
boots.push_back(det);
|
|
}
|
|
}
|
|
|
|
if (config_.debug) {
|
|
LogInfo("[LogicGateNode] Persons=" + std::to_string(persons.size()) +
|
|
" Boots=" + std::to_string(boots.size()));
|
|
}
|
|
|
|
// 简化逻辑:必须同时检测到人和鞋,才开始判断
|
|
if (persons.empty() || boots.empty()) {
|
|
return;
|
|
}
|
|
|
|
std::vector<ViolationInfo> violations;
|
|
|
|
// 对每只鞋进行颜色检查
|
|
for (const auto& boot : boots) {
|
|
if (config_.enable_color_check && color_analyzer_) {
|
|
auto color_result = color_analyzer_->Analyze(*frame, boot.bbox);
|
|
|
|
if (config_.debug) {
|
|
LogInfo("[LogicGateNode] Boot brightness=" +
|
|
std::to_string(color_result.brightness) +
|
|
" is_dark=" + (color_result.is_dark ? "true" : "false"));
|
|
}
|
|
|
|
if (!color_result.is_dark) {
|
|
// 颜色不对,添加 no_boots 检测框
|
|
Detection no_boots_det;
|
|
no_boots_det.cls_id = 10; // no_boots
|
|
no_boots_det.track_id = std::max(0, boot.track_id);
|
|
no_boots_det.score = std::max(0.5f, color_result.confidence);
|
|
no_boots_det.bbox = boot.bbox;
|
|
|
|
frame->det->items.push_back(no_boots_det);
|
|
|
|
if (config_.debug) {
|
|
LogInfo("[LogicGateNode] VIOLATION: Non-compliant boots color (brightness=" +
|
|
std::to_string(static_cast<int>(color_result.brightness)) +
|
|
") added no_boots(cls=10) track_id=" + std::to_string(no_boots_det.track_id));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
// Associates LogicGateNode with the name "logic_gate", the same string Type() returns.
// NOTE(review): REGISTER_NODE is defined outside this file; presumably it registers
// the class with a node factory — confirm against the macro definition.
REGISTER_NODE(LogicGateNode, "logic_gate");
|
|
|
|
} // namespace rk3588
|