OrangePi3588Media/plugins/logic_gate/spatial_matcher.cpp

107 lines
3.5 KiB
C++

#include "spatial_matcher.h"

#include <algorithm>
#include <cmath>
#include <limits>
namespace rk3588 {
// Keeps a private copy of the matching configuration.
// Match() and CalculateFootRegion() read their thresholds and ratios from this copy.
SpatialMatcher::SpatialMatcher(const SpatialConfig& config)
    : config_(config) {}
// Tests whether `target` satisfies the configured spatial relation (config_.relation)
// with respect to `anchor`.
//
// @param anchor  Reference detection (for BELOW, its box is treated as a person box).
// @param target  Detection that is tested against the anchor.
// @param img_w   Image width; forwarded to the foot-region clipping (BELOW)
//                and to distance normalisation (NEAR).
// @param img_h   Image height; same uses as img_w.
// @return A SpatialMatchResult whose `matched` flag reports the outcome.
//   - BELOW fills `foot_region` and `iou`.
//   - INSIDE and OVERLAP fill `iou`.
//   - NEAR fills `distance`.
//   Fields not touched keep their default-constructed values. A relation not
//   listed here yields a default-constructed result.
SpatialMatchResult SpatialMatcher::Match(const Detection& anchor, const Detection& target,
                                         int img_w, int img_h) {
  SpatialMatchResult outcome;
  const auto relation = config_.relation;

  if (relation == SpatialRelation::BELOW) {
    // Score the target against the area directly under the anchor box.
    outcome.foot_region = CalculateFootRegion(anchor.bbox, img_w, img_h);
    outcome.iou = CalculateIoU(outcome.foot_region, target.bbox);
    outcome.matched = outcome.iou >= config_.iou_threshold;
  } else if (relation == SpatialRelation::INSIDE) {
    // NOTE(review): "inside" is scored with plain IoU, so a small target fully
    // contained in a large anchor can score low. Containment
    // (intersection / target area) may be intended; confirm before changing,
    // since iou_threshold values are tuned against the current metric.
    outcome.iou = CalculateIoU(anchor.bbox, target.bbox);
    outcome.matched = outcome.iou >= config_.iou_threshold;
  } else if (relation == SpatialRelation::OVERLAP) {
    // Any positive overlap counts as a match; iou_threshold is not consulted.
    outcome.iou = CalculateIoU(anchor.bbox, target.bbox);
    outcome.matched = outcome.iou > 0;
  } else if (relation == SpatialRelation::NEAR) {
    // Centre-to-centre distance in image-normalised coordinates.
    outcome.distance = CalculateDistance(anchor.bbox, target.bbox, img_w, img_h);
    outcome.matched = outcome.distance <= config_.max_distance;
  }
  return outcome;
}
// Builds the "foot" region directly beneath a person bounding box, clipped to the image.
//
// Before clipping, the region:
//   - starts at the bottom edge of `person_bbox` (y + h);
//   - is centred horizontally on the box;
//   - is `expand_ratio` times as wide as the box;
//   - is `foot_ratio * expand_ratio` times as tall as the box.
// Each edge is then clipped to [0, img_w] x [0, img_h] independently.
//
// @param person_bbox  Anchor box; (x, y) is its minimum corner, (w, h) its size.
// @param img_w        Image width used for clipping.
// @param img_h        Image height used for clipping.
// @return The clipped region. Width/height are never negative; they are 0 when the
//         region lies entirely outside the image.
Rect SpatialMatcher::CalculateFootRegion(const Rect& person_bbox, int img_w, int img_h) {
  const float expand = config_.expand_ratio;
  const float foot_height = person_bbox.h * config_.foot_ratio;

  // Unclipped edges of the region.
  const float left = person_bbox.x - person_bbox.w * (expand - 1.0f) * 0.5f;
  const float top = person_bbox.y + person_bbox.h;  // begins at the person's bottom edge
  const float right = left + person_bbox.w * expand;
  const float bottom = top + foot_height * expand;

  // Clip the edges, not origin+size. Clamping only the origin to 0 without
  // shrinking the extent would translate the region instead of cropping it.
  const float max_x = static_cast<float>(img_w);
  const float max_y = static_cast<float>(img_h);
  const float x1 = std::min(std::max(left, 0.0f), max_x);
  const float y1 = std::min(std::max(top, 0.0f), max_y);
  const float x2 = std::min(std::max(right, 0.0f), max_x);
  const float y2 = std::min(std::max(bottom, 0.0f), max_y);

  Rect foot;
  foot.x = x1;
  foot.y = y1;
  // Floor at 0: a region starting at or beyond the image edge must not report
  // a negative size.
  foot.w = std::max(0.0f, x2 - x1);
  foot.h = std::max(0.0f, y2 - y1);
  return foot;
}
// Intersection-over-union of two axis-aligned boxes given as (x, y, w, h).
//
// Returns 0 in two cases:
//   - the boxes do not overlap (shared edges count as no overlap);
//   - the union area is not positive.
// Otherwise returns intersection / union.
float SpatialMatcher::CalculateIoU(const Rect& a, const Rect& b) {
  // Horizontal overlap interval.
  const float left = std::max(a.x, b.x);
  const float right = std::min(a.x + a.w, b.x + b.w);
  if (right <= left) {
    return 0.0f;
  }

  // Vertical overlap interval.
  const float top = std::max(a.y, b.y);
  const float bottom = std::min(a.y + a.h, b.y + b.h);
  if (bottom <= top) {
    return 0.0f;
  }

  const float overlap = (right - left) * (bottom - top);
  const float area_first = a.w * a.h;
  const float area_second = b.w * b.h;
  const float combined = area_first + area_second - overlap;
  if (combined > 0) {
    return overlap / combined;
  }
  return 0.0f;
}
// Euclidean distance between the centres of two boxes, in image-normalised coordinates.
//
// x is divided by img_w and y by img_h, so a distance of 1.0 spans the full
// width (or height) of the image.
//
// @param anchor  First box (x, y, w, h).
// @param target  Second box (x, y, w, h).
// @param img_w   Image width used to normalise x.
// @param img_h   Image height used to normalise y.
// @return The normalised centre distance. Returns +infinity when img_w or img_h
//         is not positive, so a NEAR comparison against any finite max_distance
//         fails instead of propagating inf/NaN.
float SpatialMatcher::CalculateDistance(const Rect& anchor, const Rect& target, int img_w, int img_h) {
  // Guard the divisions below: a zero (or negative) image size would otherwise
  // yield inf/NaN centres.
  if (img_w <= 0 || img_h <= 0) {
    return std::numeric_limits<float>::infinity();
  }

  const float width = static_cast<float>(img_w);
  const float height = static_cast<float>(img_h);

  const float cx1 = (anchor.x + anchor.w * 0.5f) / width;
  const float cy1 = (anchor.y + anchor.h * 0.5f) / height;
  const float cx2 = (target.x + target.w * 0.5f) / width;
  const float cy2 = (target.y + target.h * 0.5f) / height;

  const float dx = cx1 - cx2;
  const float dy = cy1 - cy2;
  return std::sqrt(dx * dx + dy * dy);
}
// Coarse "below" test: true when the top edge of `target` is at or beneath a
// reference line 80% of the way down `anchor`. This file treats y + h as a box's
// bottom edge, so larger y is lower.
// The image dimensions are currently unused. They are left unnamed to avoid
// unused-parameter warnings.
bool SpatialMatcher::IsBelow(const Rect& anchor, const Rect& target, int /*img_w*/, int /*img_h*/) {
  const float reference_y = anchor.y + anchor.h * 0.8f;
  return reference_y <= target.y;
}
} // namespace rk3588