109 lines
3.5 KiB
C++
109 lines
3.5 KiB
C++
#include "spatial_matcher.h"

#include <algorithm>
#include <cmath>
#include <limits>
|
|
|
|
namespace rk3588 {
|
|
|
|
SpatialMatcher::SpatialMatcher(const SpatialConfig& config) : config_(config) {}
|
|
|
|
/// Evaluates the configured spatial relation (`config_.relation`) between an
/// anchor detection and a target detection.
///
/// @param anchor  Reference detection; for BELOW its box is the one whose foot
///                region is computed.
/// @param target  Detection tested against the anchor.
/// @param img_w   Image width (same coordinate units as the bboxes); used for
///                clipping the foot region and normalising distances.
/// @param img_h   Image height (same units as img_w).
/// @return A default-constructed SpatialMatchResult with the fields relevant to
///         the relation filled in:
///           - BELOW:   foot_region; iou = IoU(foot_region, target);
///                      matched = iou >= iou_threshold
///           - INSIDE:  iou = IoU(anchor, target); matched = iou >= iou_threshold
///           - OVERLAP: iou = IoU(anchor, target); matched = iou > 0
///           - NEAR:    distance = normalised centre distance;
///                      matched = distance <= max_distance
///         Any other relation value leaves the result default-constructed.
SpatialMatchResult SpatialMatcher::Match(const Detection& anchor, const Detection& target,
                                         int img_w, int img_h) {
    SpatialMatchResult out;
    const auto relation = config_.relation;

    if (relation == SpatialRelation::BELOW) {
        // Build the region under the anchor's box, then score the target's
        // overlap with that region.
        out.foot_region = CalculateFootRegion(anchor.bbox, img_w, img_h);
        out.iou = CalculateIoU(out.foot_region, target.bbox);
        out.matched = (out.iou >= config_.iou_threshold);
    } else if (relation == SpatialRelation::INSIDE) {
        // NOTE(review): this is symmetric IoU, not containment. A small target
        // lying fully inside a large anchor can still score a low IoU.
        out.iou = CalculateIoU(anchor.bbox, target.bbox);
        out.matched = (out.iou >= config_.iou_threshold);
    } else if (relation == SpatialRelation::OVERLAP) {
        out.iou = CalculateIoU(anchor.bbox, target.bbox);
        out.matched = (out.iou > 0);  // any overlap at all counts as a match
    } else if (relation == SpatialRelation::NEAR) {
        out.distance = CalculateDistance(anchor.bbox, target.bbox, img_w, img_h);
        out.matched = (out.distance <= config_.max_distance);
    }
    // Unrecognised relation: fall through with the default-constructed result.

    return out;
}
|
|
|
|
/// Computes the "foot" region directly beneath a person box, clipped to the
/// image.
///
/// Before clipping, the region:
///   - starts at the bottom edge of `person_bbox` (y = bbox.y + bbox.h);
///   - is `bbox.h * foot_ratio * expand_ratio` tall (extending downward);
///   - is `bbox.w * expand_ratio` wide, centred horizontally on the box.
///
/// Clipping keeps the far edge fixed when the near edge is clamped. A region
/// that lies entirely outside the image collapses to zero width/height instead
/// of carrying a negative extent.
///
/// @param person_bbox  Box whose bottom edge anchors the region.
/// @param img_w        Image width (same units as the bbox coordinates).
/// @param img_h        Image height.
/// @return The clipped region; w and h are always >= 0.
Rect SpatialMatcher::CalculateFootRegion(const Rect& person_bbox, int img_w, int img_h) {
    // Height of the strip below the box, before expansion.
    float foot_height = person_bbox.h * config_.foot_ratio;

    Rect foot;
    foot.x = person_bbox.x - person_bbox.w * (config_.expand_ratio - 1.0f) * 0.5f;
    foot.y = person_bbox.y + person_bbox.h;  // begins at the bottom of the person box
    foot.w = person_bbox.w * config_.expand_ratio;
    foot.h = foot_height * config_.expand_ratio;

    // Clip against the left/top image border. Shrink the extent by the amount
    // cut off, so the right/bottom edge stays where it was. Without this, the
    // region would be shifted and enlarged.
    if (foot.x < 0.0f) {
        foot.w += foot.x;
        foot.x = 0.0f;
    }
    if (foot.y < 0.0f) {
        foot.h += foot.y;
        foot.y = 0.0f;
    }

    // Clip against the right/bottom image border.
    foot.w = std::min(static_cast<float>(img_w) - foot.x, foot.w);
    foot.h = std::min(static_cast<float>(img_h) - foot.y, foot.h);

    // A region outside the image (e.g. a person box touching the bottom edge)
    // becomes empty rather than having a negative size.
    foot.w = std::max(0.0f, foot.w);
    foot.h = std::max(0.0f, foot.h);

    return foot;
}
|
|
|
|
float SpatialMatcher::CalculateIoU(const Rect& a, const Rect& b) {
|
|
// 计算交集
|
|
float x1 = std::max(a.x, b.x);
|
|
float y1 = std::max(a.y, b.y);
|
|
float x2 = std::min(a.x + a.w, b.x + b.w);
|
|
float y2 = std::min(a.y + a.h, b.y + b.h);
|
|
|
|
if (x2 <= x1 || y2 <= y1) {
|
|
return 0.0f;
|
|
}
|
|
|
|
float intersection = (x2 - x1) * (y2 - y1);
|
|
float area_a = a.w * a.h;
|
|
float area_b = b.w * b.h;
|
|
float union_area = area_a + area_b - intersection;
|
|
|
|
return union_area > 0 ? intersection / union_area : 0.0f;
|
|
}
|
|
|
|
float SpatialMatcher::CalculateDistance(const Rect& anchor, const Rect& target, int img_w, int img_h) {
|
|
// 计算中心点距离(归一化)
|
|
float cx1 = (anchor.x + anchor.w * 0.5f) / img_w;
|
|
float cy1 = (anchor.y + anchor.h * 0.5f) / img_h;
|
|
float cx2 = (target.x + target.w * 0.5f) / img_w;
|
|
float cy2 = (target.y + target.h * 0.5f) / img_h;
|
|
|
|
float dx = cx1 - cx2;
|
|
float dy = cy1 - cy2;
|
|
|
|
return std::sqrt(dx * dx + dy * dy);
|
|
}
|
|
|
|
bool SpatialMatcher::IsBelow(const Rect& anchor, const Rect& target, int img_w, int img_h) {
|
|
(void)img_w;
|
|
(void)img_h;
|
|
// 简单判断:目标的顶部是否在锚点底部的下方
|
|
return target.y >= (anchor.y + anchor.h * 0.8f);
|
|
}
|
|
|
|
} // namespace rk3588
|