// File: OrangePi3588Media/plugins/logic_gate/logic_gate_node.cpp (C++)

#include <map>
#include <vector>
#include <memory>
#include <chrono>
#include <algorithm>
#include <cstring>
#include "node.h"
#include "frame/frame.h"
#include "utils/logger.h"
#include "utils/simple_json.h"
#include "spatial_matcher.h"
#include "color_analyzer.h"
namespace rk3588 {
/// Categories of PPE violation that this plugin can describe.
enum class ViolationType {
    NONE = 0,               ///< No violation.
    MISSING_BOOTS = 1,      ///< No boots were found.
    WRONG_COLOR_BOOTS = 2,  ///< Boots were found but their colour is not compliant.
};
/// Description of a single PPE violation.
struct ViolationInfo {
    ViolationType type{ViolationType::NONE};  ///< Violation category.
    std::string description;                  ///< Free-form description text.
    int track_id{-1};                         ///< Tracker id; -1 presumably means "not tracked" (confirm).
    float confidence{0.0f};                   ///< Confidence score of the violation.
    Rect violation_region;                    ///< Image region associated with the violation.
};
/// Settings for LogicGateNode, populated from the node's JSON configuration.
struct LogicGateConfig {
    std::string mode{"ppe_boots_check"};         ///< Rule to evaluate.
    int anchor_class{6};                         ///< Detection class id of the anchor object (person).
    int boots_class{3};                          ///< Detection class id of boots.
    ColorConfig color;                           ///< Settings handed to the colour analyser.
    bool enable_color_check{true};               ///< Whether boots colour is analysed at all.
    std::string violation_key{"ppe_violation"};  ///< NOTE(review): parsed, but not read by LogicGateNode in this file.
    bool pass_through{true};                     ///< NOTE(review): parsed, but not read by LogicGateNode in this file.
    bool debug{false};                           ///< Enables verbose per-frame logging.
};
class LogicGateNode : public INode {
public:
LogicGateNode() = default;
~LogicGateNode() override = default;
std::string Id() const override { return id_; }
std::string Type() const override { return "logic_gate"; }
bool Init(const SimpleJson& config, const NodeContext& ctx) override {
id_ = config.ValueOr<std::string>("id", "logic_gate");
config_ = ParseConfig(config);
output_queues_ = ctx.output_queues;
// 初始化颜色分析器
if (config_.enable_color_check) {
color_analyzer_ = std::make_unique<ColorAnalyzer>(config_.color);
}
LogInfo("[LogicGateNode] Initialized, id=" + id_ + " mode=" + config_.mode +
" output_queues=" + std::to_string(output_queues_.size()));
return true;
}
bool Start() override {
LogInfo("[LogicGateNode] started id=" + id_);
return true;
}
void Stop() override {
}
NodeStatus Process(FramePtr frame) override {
if (!frame || !frame->det) {
PushToDownstream(frame);
return NodeStatus::OK;
}
if (config_.mode == "ppe_boots_check") {
ProcessPpeBootsCheck(frame);
}
PushToDownstream(frame);
return NodeStatus::OK;
}
private:
std::string id_;
LogicGateConfig config_;
std::unique_ptr<ColorAnalyzer> color_analyzer_;
std::vector<std::shared_ptr<SpscQueue<FramePtr>>> output_queues_;
LogicGateConfig ParseConfig(const SimpleJson& cfg) {
LogicGateConfig config;
config.mode = cfg.ValueOr<std::string>("mode", "ppe_boots_check");
config.anchor_class = cfg.ValueOr<int>("anchor_class", 6);
config.boots_class = cfg.ValueOr<int>("boots_class", 3);
config.violation_key = cfg.ValueOr<std::string>("violation_key", "ppe_violation");
config.pass_through = cfg.ValueOr<bool>("pass_through", true);
config.debug = cfg.ValueOr<bool>("debug", false);
config.enable_color_check = cfg.ValueOr<bool>("enable_color_check", true);
// 解析颜色配置
if (const SimpleJson* color = cfg.Find("color_check")) {
std::string method = color->ValueOr<std::string>("method", "hsv");
if (method == "hsv") config.color.method = ColorMethod::HSV;
else if (method == "rgb") config.color.method = ColorMethod::RGB;
else if (method == "brightness") config.color.method = ColorMethod::BRIGHTNESS;
config.color.dark_threshold = color->ValueOr<int>("dark_threshold", 80);
config.color.roi_expand = color->ValueOr<float>("roi_expand", 1.0f);
config.color.debug_output = config.debug;
}
return config;
}
void PushToDownstream(FramePtr frame) {
for (auto& q : output_queues_) {
if (q) q->Push(frame);
}
}
void ProcessPpeBootsCheck(FramePtr frame) {
const auto& detections = frame->det->items;
// 收集所有人和鞋
std::vector<Detection> persons;
std::vector<Detection> boots;
for (const auto& det : detections) {
if (det.cls_id == config_.anchor_class) {
persons.push_back(det);
} else if (det.cls_id == config_.boots_class) {
boots.push_back(det);
}
}
if (config_.debug) {
LogInfo("[LogicGateNode] Persons=" + std::to_string(persons.size()) +
" Boots=" + std::to_string(boots.size()));
}
// 简化逻辑:必须同时检测到人和鞋,才开始判断
if (persons.empty() || boots.empty()) {
return;
}
std::vector<ViolationInfo> violations;
// 对每只鞋进行颜色检查
for (const auto& boot : boots) {
if (config_.enable_color_check && color_analyzer_) {
auto color_result = color_analyzer_->Analyze(*frame, boot.bbox);
if (config_.debug) {
LogInfo("[LogicGateNode] Boot brightness=" +
std::to_string(color_result.brightness) +
" is_dark=" + (color_result.is_dark ? "true" : "false"));
}
if (!color_result.is_dark) {
// 颜色不对,添加 no_boots 检测框
Detection no_boots_det;
no_boots_det.cls_id = 10; // no_boots
no_boots_det.track_id = std::max(0, boot.track_id);
no_boots_det.score = std::max(0.5f, color_result.confidence);
no_boots_det.bbox = boot.bbox;
frame->det->items.push_back(no_boots_det);
if (config_.debug) {
LogInfo("[LogicGateNode] VIOLATION: Non-compliant boots color (brightness=" +
std::to_string(static_cast<int>(color_result.brightness)) +
") added no_boots(cls=10) track_id=" + std::to_string(no_boots_det.track_id));
}
}
}
}
}
};
// Register LogicGateNode under the type name "logic_gate", the same string that
// LogicGateNode::Type() returns.
// NOTE(review): REGISTER_NODE is a project macro that is not shown here. It
// presumably adds the class to a type-name -> factory registry; verify in its header.
REGISTER_NODE(LogicGateNode, "logic_gate");
} // namespace rk3588