Tighten person ROI sample filtering

This commit is contained in:
tian 2026-03-17 22:20:53 +08:00
parent 3ce6a96b14
commit feb52b5c2d

View File

@ -40,6 +40,16 @@ PAIR_MAX_AREA_RATIO = 2.5
# (min, max) area ranges; not referenced in the visible hunks — presumably used
# by the non-person ROI sampling path. TODO confirm against make_roi_samples.
SINGLE_AREA_RANGE = (0.10, 0.26)
PAIR_AREA_RANGE = (0.18, 0.40)
# --- shoe_matches_person: search window derived from a person box ---
# Fraction of the person box width added on the left and on the right.
PERSON_MATCH_X_MARGIN = 0.18
# Window top edge = person.y1 + ratio * person.h; a shoe whose centre lies
# above this edge is rejected.
PERSON_MATCH_TOP_RATIO = 0.45
# Window bottom edge = person.y1 + ratio * person.h; a ratio above 1.0 puts the
# edge below person.y1 + person.h.
PERSON_MATCH_BOTTOM_RATIO = 1.08
# Minimum share of the shoe box area that must fall inside the window
# (intersection area / shoe area).
PERSON_MATCH_MIN_IOA = 0.6
# --- person_roi_boxes_valid: (min, max) bounds checked per ROI sample ---
# Sum of all box areas divided by the ROI area.
PERSON_ROI_TOTAL_AREA_RANGE = (0.015, 0.28)
# Area of each individual box divided by the ROI area.
PERSON_ROI_SINGLE_AREA_RANGE = (0.008, 0.22)
# box.cy / roi.h for each box.
PERSON_ROI_CENTER_Y_RANGE = (0.42, 0.98)
# box.y2 / roi.h for each box.
PERSON_ROI_BOTTOM_Y_RANGE = (0.58, 1.0)
# box.cx / roi.w for each box.
PERSON_ROI_CENTER_X_RANGE = (0.03, 0.97)
# Upper limits for (box.w / roi.w, box.h / roi.h).
PERSON_ROI_MAX_BOX_SIZE = (0.78, 0.88)
@dataclass(frozen=True)
@ -246,6 +256,22 @@ def iou(a: Box, b: Box) -> float:
return inter_area / union if union > 0 else 0.0
def intersection_area(a: Box, b: Box) -> float:
    """Return the area of the region shared by boxes *a* and *b*.

    The overlap extent along each axis is clamped at zero, so boxes that are
    disjoint (or only touch along an edge) yield an area of zero.
    """
    overlap_w = min(a.x2, b.x2) - max(a.x1, b.x1)
    overlap_h = min(a.y2, b.y2) - max(a.y1, b.y1)
    return max(0.0, overlap_w) * max(0.0, overlap_h)
def ioa(inner: Box, outer: Box) -> float:
    """Return intersection-over-area: the share of *inner* covered by *outer*.

    The denominator is the area of *inner* only, so the result is not
    symmetric in its arguments. A box with non-positive area yields 0.0.
    """
    inner_area = inner.area
    return 0.0 if inner_area <= 0 else intersection_area(inner, outer) / inner_area
def should_pair(left: Box, right: Box) -> bool:
width_ref = max(left.w, right.w)
height_ref = max(left.h, right.h)
@ -411,13 +437,76 @@ def boxes_in_roi(boxes: list[Box], roi: Box) -> list[Box]:
return included
def make_person_roi_samples(person_boxes: list[Box], image_width: int, image_height: int) -> list[RoiSample]:
def shoe_matches_person(
    person_box: Box,
    shoe_box: Box,
    *,
    min_bottom_ratio: float = 0.55,
) -> bool:
    """Decide whether *shoe_box* is positioned where *person_box*'s feet would be.

    A search window is derived from the person box: it is widened by
    ``PERSON_MATCH_X_MARGIN`` of the person width on each side and spans
    vertically from ``PERSON_MATCH_TOP_RATIO`` to ``PERSON_MATCH_BOTTOM_RATIO``
    of the person height, measured from the person box's top edge.

    Args:
        person_box: Box of the detected person.
        shoe_box: Candidate shoe box, in the same coordinate frame.
        min_bottom_ratio: The shoe's bottom edge must reach at least this
            fraction of the person height below the person's top edge. The
            default reproduces the previously hard-coded threshold.

    Returns:
        True when at least ``PERSON_MATCH_MIN_IOA`` of the shoe's area lies
        inside the window, the shoe's centre lies inside the window, and the
        shoe's bottom edge passes the ``min_bottom_ratio`` check.
    """
    # Computed once: used both as the window's top edge and as the lower bound
    # for the shoe centre.
    window_top = person_box.y1 + (PERSON_MATCH_TOP_RATIO * person_box.h)
    window = Box(
        person_box.x1 - (PERSON_MATCH_X_MARGIN * person_box.w),
        window_top,
        person_box.x2 + (PERSON_MATCH_X_MARGIN * person_box.w),
        person_box.y1 + (PERSON_MATCH_BOTTOM_RATIO * person_box.h),
    )

    if ioa(shoe_box, window) < PERSON_MATCH_MIN_IOA:
        return False
    if shoe_box.cy < window_top:
        return False
    if shoe_box.y2 < person_box.y1 + (min_bottom_ratio * person_box.h):
        return False
    if shoe_box.cx < window.x1 or shoe_box.cx > window.x2:
        return False
    if shoe_box.cy > window.y2:
        return False
    return True
def person_roi_boxes_valid(roi_boxes: list[Box], roi: Box) -> bool:
    """Check that the boxes of a person ROI sample fit the configured bounds.

    The sample is rejected when there are no boxes, when the ROI has
    non-positive area, when the summed box area relative to the ROI falls
    outside ``PERSON_ROI_TOTAL_AREA_RANGE``, or when any single box violates
    one of the per-box limits on centre position, bottom edge, size or area.

    NOTE(review): the ratios divide box coordinates directly by the ROI size,
    which presumes the boxes are expressed relative to the ROI origin —
    confirm against ``boxes_in_roi``.
    """
    if not roi_boxes or roi.area <= 0:
        return False

    def within(value: float, bounds: tuple[float, float]) -> bool:
        # Inclusive on both ends.
        return bounds[0] <= value <= bounds[1]

    coverage = sum(item.area for item in roi_boxes) / roi.area
    if not within(coverage, PERSON_ROI_TOTAL_AREA_RANGE):
        return False

    for item in roi_boxes:
        share = item.area / roi.area
        center_x = item.cx / roi.w
        center_y = item.cy / roi.h
        bottom = item.y2 / roi.h

        placed_ok = (
            within(center_x, PERSON_ROI_CENTER_X_RANGE)
            and within(center_y, PERSON_ROI_CENTER_Y_RANGE)
            and within(bottom, PERSON_ROI_BOTTOM_Y_RANGE)
        )
        fits = (
            (item.w / roi.w) <= PERSON_ROI_MAX_BOX_SIZE[0]
            and (item.h / roi.h) <= PERSON_ROI_MAX_BOX_SIZE[1]
        )
        if not (placed_ok and fits):
            return False
        if not within(share, PERSON_ROI_SINGLE_AREA_RANGE):
            return False

    return True
def make_person_roi_samples(
    person_boxes: list[Box],
    shoe_boxes: list[Box],
    image_width: int,
    image_height: int,
) -> list[RoiSample]:
    """Build ``"person"``-mode ROI samples for persons that have matching shoes.

    For each person box, the shoes accepted by ``shoe_matches_person`` are
    collected. A sample is emitted only when at least one shoe matches, the
    clamped ROI exists, and the matched shoes inside that ROI pass
    ``person_roi_boxes_valid``.

    Args:
        person_boxes: Person boxes for one image.
        shoe_boxes: Shoe boxes for the same image.
        image_width: Image width passed to ``clamp_roi``.
        image_height: Image height passed to ``clamp_roi``.

    Returns:
        One ``RoiSample`` per accepted person. ``members`` holds indices into
        ``shoe_boxes`` (not into ``person_boxes``), so callers can look the
        matched shoes up directly.
    """
    samples: list[RoiSample] = []
    for person_box in person_boxes:
        matched_indices = tuple(
            shoe_idx
            for shoe_idx, shoe_box in enumerate(shoe_boxes)
            if shoe_matches_person(person_box, shoe_box)
        )
        if not matched_indices:
            continue

        roi = roi_from_person_box(person_box)
        roi = clamp_roi(roi, image_width, image_height)
        # clamp_roi can yield None; such a person produces no sample.
        if roi is None:
            continue

        roi_boxes = boxes_in_roi([shoe_boxes[idx] for idx in matched_indices], roi)
        if not person_roi_boxes_valid(roi_boxes, roi):
            continue

        # Exactly one sample per person, appended only after validation. The
        # members are shoe indices because the caller indexes shoe_boxes with
        # them.
        samples.append(RoiSample(roi=roi, members=matched_indices, mode="person"))
    return samples
@ -520,12 +609,13 @@ def build_split(source_spec: SourceSpec, output_dir: Path, split: str) -> dict[s
continue
if source_spec.uses_person_boxes and person_boxes:
samples = make_person_roi_samples(person_boxes, width, height)
samples = make_person_roi_samples(person_boxes, shoe_boxes, width, height)
else:
samples = make_roi_samples(shoe_boxes, width, height)
for sample_idx, sample in enumerate(samples):
roi_boxes = boxes_in_roi(shoe_boxes, sample.roi)
candidate_boxes = [shoe_boxes[idx] for idx in sample.members] if sample.members else shoe_boxes
roi_boxes = boxes_in_roi(candidate_boxes, sample.roi)
if not roi_boxes:
continue