# DetectionModelTraining/09_build_roi_shoe_dataset.py
# 2026-03-17 22:20:53 +08:00

# 684 lines
# 22 KiB
# Python

#!/usr/bin/env python3
"""Build a foot-ROI shoe dataset from existing YOLO shoe datasets.
Preferred training input should come from person-bottom ROIs, matching online inference:
roi_x = x - 0.24w
roi_y = y + 0.64h
roi_w = 1.48w
roi_h = 0.58h
When person boxes are available, this script uses them directly.
When only shoe boxes are available, it falls back to shoe-based ROI approximation
that still tries to match person-bottom input distribution rather than shoe closeups.
"""
from __future__ import annotations
import argparse
import ast
import math
import shutil
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from PIL import Image
# Source YOLO datasets mixed by default when --sources is not given.
DEFAULT_SOURCES = [
    "datasets/ppe-person-shoes",
    "datasets/openimages-person-shoes-yolo",
    "datasets/openimages-shoes-yolo",
]
# Extensions probed (in order) when resolving an image for a label stem.
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
# Shoe-pair grouping thresholds, relative to the larger box of the pair.
PAIR_MAX_X_GAP_FACTOR = 3.2
PAIR_MAX_Y_GAP_FACTOR = 1.2
PAIR_MIN_AREA_RATIO = 0.4  # = 1 / PAIR_MAX_AREA_RATIO, so the ratio check is symmetric
PAIR_MAX_AREA_RATIO = 2.5
# Target shoe-area / ROI-area ranges for the shoe-box fallback ROIs.
SINGLE_AREA_RANGE = (0.10, 0.26)
PAIR_AREA_RANGE = (0.18, 0.40)
# Person-to-shoe match window, expressed as fractions of the person box.
PERSON_MATCH_X_MARGIN = 0.18
PERSON_MATCH_TOP_RATIO = 0.45
PERSON_MATCH_BOTTOM_RATIO = 1.08
PERSON_MATCH_MIN_IOA = 0.6
# Sanity limits for shoe boxes inside a person-bottom ROI (fractions of the ROI).
# Values appear hand-tuned for this pipeline — adjust with care.
PERSON_ROI_TOTAL_AREA_RANGE = (0.015, 0.28)
PERSON_ROI_SINGLE_AREA_RANGE = (0.008, 0.22)
PERSON_ROI_CENTER_Y_RANGE = (0.42, 0.98)
PERSON_ROI_BOTTOM_Y_RANGE = (0.58, 1.0)
PERSON_ROI_CENTER_X_RANGE = (0.03, 0.97)
PERSON_ROI_MAX_BOX_SIZE = (0.78, 0.88)
@dataclass(frozen=True)
class Box:
    """Axis-aligned box in pixel coordinates: (x1, y1) top-left, (x2, y2) bottom-right."""

    x1: float
    y1: float
    x2: float
    y2: float

    @property
    def w(self) -> float:
        """Width, floored at zero for degenerate boxes."""
        return max(0.0, self.x2 - self.x1)

    @property
    def h(self) -> float:
        """Height, floored at zero for degenerate boxes."""
        return max(0.0, self.y2 - self.y1)

    @property
    def area(self) -> float:
        """Box area, w * h."""
        return self.w * self.h

    @property
    def cx(self) -> float:
        """Horizontal center."""
        return (self.x1 + self.x2) / 2.0

    @property
    def cy(self) -> float:
        """Vertical center."""
        return (self.y1 + self.y2) / 2.0

    def clip(self, width: float, height: float) -> "Box | None":
        """Clamp the box into [0, width] x [0, height]; None if it collapses."""
        left, top, right, bottom = (
            min(max(coord, 0.0), bound)
            for coord, bound in (
                (self.x1, width),
                (self.y1, height),
                (self.x2, width),
                (self.y2, height),
            )
        )
        if right <= left or bottom <= top:
            return None
        return Box(left, top, right, bottom)
@dataclass(frozen=True)
class RoiSample:
    """One ROI crop candidate produced for an image."""

    # Crop rectangle in original-image pixel coordinates.
    roi: Box
    # Indices of the shoe boxes this ROI was derived from; empty means "consider all".
    members: tuple[int, ...]
    # How the ROI was built: "person", "pair", or "single".
    mode: str
@dataclass(frozen=True)
class SourceSpec:
    """Resolved class-id layout for one source YOLO dataset."""

    # Root directory containing images/<split> and labels/<split>.
    dataset_dir: Path
    # Class ids treated as person boxes (may be empty).
    person_ids: set[int]
    # Class ids treated as shoe boxes.
    shoe_ids: set[int]
    # True when the dataset labels persons, enabling person-bottom ROIs.
    uses_person_boxes: bool
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the ROI dataset builder."""
    parser = argparse.ArgumentParser(description="Build a foot-context ROI shoe dataset")
    parser.add_argument("--sources", nargs="+", default=DEFAULT_SOURCES,
                        help="Source YOLO datasets containing images/<split> and labels/<split>")
    parser.add_argument("--output", default="datasets/shoe-roi-mix",
                        help="Output ROI dataset directory")
    parser.add_argument("--clean", action="store_true",
                        help="Delete the output directory before rebuilding")
    return parser.parse_args()
def ensure_output_layout(output_dir: Path) -> None:
    """Create images/<split> and labels/<split> directories for all three splits."""
    for kind in ("images", "labels"):
        for split in ("train", "val", "test"):
            (output_dir / kind / split).mkdir(parents=True, exist_ok=True)
def find_image(image_dir: Path, stem: str, exts: tuple[str, ...] | None = None) -> Path | None:
    """Return the first existing image named ``<stem><ext>`` under *image_dir*.

    Args:
        image_dir: Directory holding the split's images.
        stem: Label-file stem to resolve.
        exts: Extensions to probe, in order; defaults to ``IMAGE_EXTS``.

    Returns:
        The matching path, or ``None`` when no candidate exists.
    """
    if exts is None:
        exts = IMAGE_EXTS
    for ext in exts:
        # Also probe the upper-case spelling (common in camera exports) so
        # files like FOO.JPG are not silently skipped on case-sensitive
        # filesystems.
        for candidate in (image_dir / f"{stem}{ext}", image_dir / f"{stem}{ext.upper()}"):
            if candidate.exists():
                return candidate
    return None
def parse_names_from_yaml(yaml_path: Path) -> dict[int, str]:
    """Extract the class-id -> name mapping from a minimal YOLO data.yaml.

    Handles both the inline list form (``names: ['a', 'b']``) and the
    indented mapping form (``names:`` followed by ``  0: a`` lines). Only the
    first ``names:`` entry found is considered.
    """
    names: dict[int, str] = {}
    lines = yaml_path.read_text(encoding="utf-8").splitlines()
    for index, raw in enumerate(lines):
        stripped = raw.strip()
        if not stripped.startswith("names:"):
            continue
        remainder = stripped[len("names:"):].strip()
        if remainder:
            parsed = ast.literal_eval(remainder)
            if isinstance(parsed, list):
                return dict(enumerate(map(str, parsed)))
        for follower in lines[index + 1:]:
            # Child entries of the mapping form are indented; the first
            # non-indented line ends the block.
            if not follower.startswith(" "):
                break
            body = follower.strip()
            if ":" not in body:
                continue
            key_part, _, value_part = body.partition(":")
            if key_part.strip().isdigit():
                names[int(key_part.strip())] = value_part.strip().strip("'\"")
        break
    return names
def resolve_source_spec(source_dir: Path) -> SourceSpec:
    """Work out which class ids mean person/shoe for one source dataset."""
    names: dict[int, str] = {}
    # Prefer data.yaml because ROI-source exports intentionally rewrite class ids.
    for yaml_file in (source_dir / "data.yaml", source_dir / "dataset.yaml"):
        if not yaml_file.exists():
            continue
        names = parse_names_from_yaml(yaml_file)
        if names:
            break
    lowered = {class_id: label.lower() for class_id, label in names.items()}
    person_labels = {"person", "man", "woman", "boy", "girl"}
    shoe_labels = {"shoe", "footwear", "boot", "boots", "no_boots", "sandal", "high heels"}
    person_ids = {class_id for class_id, label in lowered.items() if label in person_labels}
    shoe_ids = {class_id for class_id, label in lowered.items() if label in shoe_labels}
    if not names:
        # No readable yaml at all: assume a single-class shoe dataset with id 0.
        shoe_ids = {0}
    if not shoe_ids:
        raise RuntimeError(f"未能在 {source_dir} 识别鞋类标签")
    return SourceSpec(
        dataset_dir=source_dir,
        person_ids=person_ids,
        shoe_ids=shoe_ids,
        uses_person_boxes=bool(person_ids),
    )
def load_annotations(
    label_path: Path,
    image_width: int,
    image_height: int,
    allowed_ids: set[int],
) -> list[tuple[int, Box]]:
    """Read a YOLO label file and return de-duplicated pixel-space boxes.

    Args:
        label_path: Path to the ``.txt`` YOLO label file.
        image_width: Image width in pixels, used to denormalize coordinates.
        image_height: Image height in pixels.
        allowed_ids: Class ids to keep; rows with other ids are dropped.

    Returns:
        ``(class_id, Box)`` pairs clipped to the image, with tiny (<= 1 px^2)
        boxes removed and near-duplicates collapsed.
    """
    annotations: list[tuple[int, Box]] = []
    for raw_line in label_path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line:
            continue
        parts = line.split()
        if len(parts) < 5:
            continue
        # Robustness fix: tolerate malformed rows (non-numeric tokens) instead
        # of aborting the whole dataset build on one bad exporter line.
        try:
            class_id = int(parts[0])
            xc, yc, w, h = (float(token) for token in parts[1:5])
        except ValueError:
            continue
        if class_id not in allowed_ids:
            continue
        box_w = w * image_width
        box_h = h * image_height
        center_x = xc * image_width
        center_y = yc * image_height
        box = Box(
            center_x - box_w / 2.0,
            center_y - box_h / 2.0,
            center_x + box_w / 2.0,
            center_y + box_h / 2.0,
        ).clip(image_width, image_height)
        if box is not None and box.area > 1.0:
            annotations.append((class_id, box))
    return dedupe_annotations(annotations)
def dedupe_annotations(
    annotations: list[tuple[int, Box]],
    iou_threshold: float = 0.9,
) -> list[tuple[int, Box]]:
    """Drop near-duplicate boxes of the same class (IoU >= threshold).

    Larger boxes win; the result is ordered by (cx, cy) for determinism.
    """
    unique: list[tuple[int, Box]] = []
    by_area = sorted(annotations, key=lambda entry: entry[1].area, reverse=True)
    for class_id, candidate in by_area:
        duplicate = False
        for kept_id, kept_box in unique:
            if kept_id == class_id and iou(candidate, kept_box) >= iou_threshold:
                duplicate = True
                break
        if not duplicate:
            unique.append((class_id, candidate))
    return sorted(unique, key=lambda entry: (entry[1].cx, entry[1].cy))
def iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two boxes; 0.0 when disjoint."""
    overlap = max(0.0, min(a.x2, b.x2) - max(a.x1, b.x1)) * max(
        0.0, min(a.y2, b.y2) - max(a.y1, b.y1)
    )
    if overlap <= 0:
        return 0.0
    denominator = a.area + b.area - overlap
    return overlap / denominator if denominator > 0 else 0.0
def intersection_area(a: Box, b: Box) -> float:
    """Area of the overlap between two boxes (0.0 when disjoint)."""
    width = min(a.x2, b.x2) - max(a.x1, b.x1)
    height = min(a.y2, b.y2) - max(a.y1, b.y1)
    if width <= 0.0 or height <= 0.0:
        return 0.0
    return width * height
def ioa(inner: Box, outer: Box) -> float:
    """Fraction of *inner*'s area that lies inside *outer* (0.0 for empty inner)."""
    inner_area = inner.area
    return 0.0 if inner_area <= 0 else intersection_area(inner, outer) / inner_area
def should_pair(left: Box, right: Box) -> bool:
    """Decide whether two shoe boxes plausibly form a pair on one person.

    Requires bounded horizontal/vertical center gaps (relative to the larger
    box) and a bounded area ratio between the two boxes.
    """
    ref_w = max(left.w, right.w)
    ref_h = max(left.h, right.h)
    if ref_w <= 0 or ref_h <= 0:
        return False
    if abs(left.cx - right.cx) > ref_w * PAIR_MAX_X_GAP_FACTOR:
        return False
    if abs(left.cy - right.cy) > ref_h * PAIR_MAX_Y_GAP_FACTOR:
        return False
    ratio = left.area / right.area if right.area > 0 else math.inf
    return PAIR_MIN_AREA_RATIO <= ratio <= PAIR_MAX_AREA_RATIO
def greedy_group_boxes(boxes: list[Box]) -> list[tuple[int, ...]]:
    """Greedily pair compatible shoe boxes; leftovers become singletons.

    Candidate pairs are ranked by horizontal gap plus half the vertical gap
    (smaller is better) and each box joins at most one pair.
    """
    if len(boxes) < 2:
        return [(idx,) for idx in range(len(boxes))]
    ranked: list[tuple[float, int, int]] = []
    for first in range(len(boxes)):
        for second in range(first + 1, len(boxes)):
            if should_pair(boxes[first], boxes[second]):
                gap_x = abs(boxes[first].cx - boxes[second].cx)
                gap_y = abs(boxes[first].cy - boxes[second].cy)
                ranked.append((gap_x + 0.5 * gap_y, first, second))
    ranked.sort(key=lambda entry: entry[0])
    claimed: set[int] = set()
    groups: list[tuple[int, ...]] = []
    for _, first, second in ranked:
        if first not in claimed and second not in claimed:
            claimed.update((first, second))
            groups.append((first, second))
    groups.extend((idx,) for idx in range(len(boxes)) if idx not in claimed)
    return groups
def estimate_person_from_single_shoe(box: Box) -> Box:
    """Estimate a loose full-body box from a single shoe box."""
    body_w = max(4.2 * box.w, 2.8 * box.h)
    body_h = max(7.6 * box.h, 3.4 * box.w)
    # Put the shoe near the bottom of the estimated person, leaving a small
    # ground margin below it.
    bottom = box.y2 + 0.08 * body_h
    left = box.cx - body_w / 2.0
    return Box(left, bottom - body_h, left + body_w, bottom)
def estimate_person_from_pair(boxes: list[Box], group: tuple[int, int]) -> Box:
    """Estimate a loose full-body box from a visible pair of shoes."""
    a, b = boxes[group[0]], boxes[group[1]]
    span_x1 = min(a.x1, b.x1)
    span_y1 = min(a.y1, b.y1)
    span_x2 = max(a.x2, b.x2)
    span_y2 = max(a.y2, b.y2)
    span_w = span_x2 - span_x1
    span_h = span_y2 - span_y1
    body_w = max(1.95 * span_w, 2.6 * span_h)
    body_h = max(7.8 * span_h, 2.9 * span_w)
    # Anchor the estimated body so the shoes sit just above its bottom edge.
    bottom = span_y2 + 0.08 * body_h
    left = (span_x1 + span_x2) / 2.0 - body_w / 2.0
    return Box(left, bottom - body_h, left + body_w, bottom)
def roi_from_person_box(person_box: Box) -> Box:
    """Apply the online person-bottom ROI rule, slightly loosened.

    Online inference crops roi = (x - 0.24w, y + 0.64h, 1.48w, 0.58h); the
    training crop is then expanded a little to keep extra trouser leg, ground
    and side context.
    """
    base_x = person_box.x1 - 0.24 * person_box.w
    base_y = person_box.y1 + 0.64 * person_box.h
    base_w = 1.48 * person_box.w
    base_h = 0.58 * person_box.h
    # Slightly enlarge versus the online rule.
    base_x -= 0.08 * base_w
    base_y -= 0.08 * base_h
    base_w *= 1.16
    base_h *= 1.18
    return Box(base_x, base_y, base_x + base_w, base_y + base_h)
def clamp_roi(roi: Box, image_width: int, image_height: int) -> Box | None:
    """Clip an ROI to the image and snap it to an integer pixel grid.

    Returns None when the ROI has no overlap with the image; otherwise the
    result is at least 1x1 pixel.
    """
    clipped = roi.clip(float(image_width), float(image_height))
    if clipped is None:
        return None
    left = max(0, min(int(math.floor(clipped.x1)), image_width - 1))
    top = max(0, min(int(math.floor(clipped.y1)), image_height - 1))
    right = max(left + 1, min(int(math.ceil(clipped.x2)), image_width))
    bottom = max(top + 1, min(int(math.ceil(clipped.y2)), image_height))
    return Box(float(left), float(top), float(right), float(bottom))
def resize_roi_to_ratio(
    roi: Box,
    image_width: int,
    image_height: int,
    object_area: float,
    min_ratio: float,
    max_ratio: float,
) -> Box | None:
    """Rescale *roi* about its center so object_area / roi_area lands in range.

    Runs at most three damped scaling steps toward the mid-point of
    [min_ratio, max_ratio], then clips the result to the image. Returns None
    when the ROI degenerates or *object_area* is non-positive.
    """
    if object_area <= 0:
        return None
    current = roi
    target = (min_ratio + max_ratio) / 2.0
    for _ in range(3):
        if current.area <= 0:
            return None
        ratio = object_area / current.area
        if min_ratio <= ratio <= max_ratio:
            break
        raw_scale = math.sqrt(ratio / target)
        if ratio < min_ratio:
            # ROI too large for the object: shrink, but at most to 60%.
            scale = max(0.6, min(0.95, raw_scale))
        else:
            # ROI too small: grow, but at most to 180%.
            scale = min(1.8, max(1.05, raw_scale))
        half_w = current.w * scale / 2.0
        half_h = current.h * scale / 2.0
        current = Box(
            current.cx - half_w,
            current.cy - half_h,
            current.cx + half_w,
            current.cy + half_h,
        )
    return clamp_roi(current, image_width, image_height)
def boxes_in_roi(boxes: list[Box], roi: Box) -> list[Box]:
    """Translate boxes whose centers fall inside *roi* into ROI coordinates.

    Boxes are clipped to the ROI; translated boxes of 4 px^2 or less are
    discarded.
    """
    kept: list[Box] = []
    for candidate in boxes:
        center_inside = (
            roi.x1 <= candidate.cx <= roi.x2 and roi.y1 <= candidate.cy <= roi.y2
        )
        if not center_inside:
            continue
        shifted = Box(
            candidate.x1 - roi.x1,
            candidate.y1 - roi.y1,
            candidate.x2 - roi.x1,
            candidate.y2 - roi.y1,
        ).clip(roi.w, roi.h)
        if shifted is not None and shifted.area > 4.0:
            kept.append(shifted)
    return kept
def shoe_matches_person(person_box: Box, shoe_box: Box) -> bool:
    """Check that a shoe box plausibly belongs to this person's lower body.

    The match window widens the person box horizontally and keeps only its
    lower portion; the shoe must overlap that window strongly, sit below the
    leg line, and keep its center inside the window.
    """
    leg_top = person_box.y1 + (PERSON_MATCH_TOP_RATIO * person_box.h)
    window = Box(
        person_box.x1 - (PERSON_MATCH_X_MARGIN * person_box.w),
        leg_top,
        person_box.x2 + (PERSON_MATCH_X_MARGIN * person_box.w),
        person_box.y1 + (PERSON_MATCH_BOTTOM_RATIO * person_box.h),
    )
    return (
        ioa(shoe_box, window) >= PERSON_MATCH_MIN_IOA
        and shoe_box.cy >= leg_top
        and shoe_box.y2 >= person_box.y1 + (0.55 * person_box.h)
        and window.x1 <= shoe_box.cx <= window.x2
        and shoe_box.cy <= window.y2
    )
def person_roi_boxes_valid(roi_boxes: list[Box], roi: Box) -> bool:
    """Validate shoe boxes (already in ROI coordinates) against sanity limits.

    Rejects the ROI when the combined shoe area, any single shoe's area,
    center position, bottom edge, or size falls outside the PERSON_ROI_*
    ranges.
    """
    if not roi_boxes or roi.area <= 0:
        return False
    combined_ratio = sum(item.area for item in roi_boxes) / roi.area
    low, high = PERSON_ROI_TOTAL_AREA_RANGE
    if not (low <= combined_ratio <= high):
        return False
    for item in roi_boxes:
        checks = (
            PERSON_ROI_CENTER_X_RANGE[0] <= item.cx / roi.w <= PERSON_ROI_CENTER_X_RANGE[1],
            PERSON_ROI_CENTER_Y_RANGE[0] <= item.cy / roi.h <= PERSON_ROI_CENTER_Y_RANGE[1],
            PERSON_ROI_BOTTOM_Y_RANGE[0] <= item.y2 / roi.h <= PERSON_ROI_BOTTOM_Y_RANGE[1],
            item.w / roi.w <= PERSON_ROI_MAX_BOX_SIZE[0],
            item.h / roi.h <= PERSON_ROI_MAX_BOX_SIZE[1],
            PERSON_ROI_SINGLE_AREA_RANGE[0] <= item.area / roi.area <= PERSON_ROI_SINGLE_AREA_RANGE[1],
        )
        if not all(checks):
            return False
    return True
def make_person_roi_samples(
    person_boxes: list[Box],
    shoe_boxes: list[Box],
    image_width: int,
    image_height: int,
) -> list[RoiSample]:
    """Build person-bottom ROI samples for every person with matching shoes."""
    results: list[RoiSample] = []
    for person_box in person_boxes:
        matched = tuple(
            idx for idx, shoe in enumerate(shoe_boxes) if shoe_matches_person(person_box, shoe)
        )
        if not matched:
            continue
        roi = clamp_roi(roi_from_person_box(person_box), image_width, image_height)
        if roi is None:
            continue
        relocated = boxes_in_roi([shoe_boxes[idx] for idx in matched], roi)
        if person_roi_boxes_valid(relocated, roi):
            results.append(RoiSample(roi=roi, members=matched, mode="person"))
    return results
def make_roi_samples(boxes: list[Box], image_width: int, image_height: int) -> list[RoiSample]:
    """Build fallback ROI samples from shoe boxes alone (no person labels)."""
    results: list[RoiSample] = []
    for group in greedy_group_boxes(boxes):
        if len(group) == 2:
            estimated = estimate_person_from_pair(boxes, group)
            area_range = PAIR_AREA_RANGE
            mode = "pair"
        else:
            estimated = estimate_person_from_single_shoe(boxes[group[0]])
            area_range = SINGLE_AREA_RANGE
            mode = "single"
        roi = clamp_roi(roi_from_person_box(estimated), image_width, image_height)
        if roi is None:
            continue
        group_area = sum(boxes[idx].area for idx in group)
        roi = resize_roi_to_ratio(roi, image_width, image_height, group_area, *area_range)
        if roi is not None:
            results.append(RoiSample(roi=roi, members=group, mode=mode))
    return results
def to_yolo_lines(boxes: list[Box], roi_w: float, roi_h: float) -> list[str]:
    """Serialize ROI-space boxes as single-class YOLO label rows."""
    rows: list[str] = []
    for box in boxes:
        center_x = (box.x1 + box.x2) / 2.0 / roi_w
        center_y = (box.y1 + box.y2) / 2.0 / roi_h
        rows.append(
            f"0 {center_x:.6f} {center_y:.6f} {box.w / roi_w:.6f} {box.h / roi_h:.6f}"
        )
    return rows
def write_yaml(output_dir: Path, sources: list[str]) -> None:
    """Write the output dataset's data.yaml describing splits and classes."""
    source_names = ", ".join(Path(item).name for item in sources)
    content_lines = [
        "# ROI shoe training mix",
        "",
        f"path: {output_dir.resolve().as_posix()}",
        "train: images/train",
        "val: images/val",
        "test: images/test",
        "",
        "nc: 1",
        "names: ['shoe']",
        "",
        "dataset_info:",
        " name: shoe-roi-mix",
        " task: detect_shoe_roi",
        f" source: {source_names}",
        " note: prefer person-bottom ROIs; current public data uses shoe-box fallback crops",
        "",
    ]
    (output_dir / "data.yaml").write_text("\n".join(content_lines), encoding="utf-8")
def build_split(source_spec: SourceSpec, output_dir: Path, split: str) -> dict[str, int]:
    """Crop ROI samples for one split of one source dataset.

    For each labelled image: load its shoe (and optional person) boxes,
    derive ROI crops (person-bottom ROIs when person labels exist, shoe-based
    estimates otherwise), then write every surviving crop as a JPEG plus a
    single-class YOLO label file into the output dataset.

    Returns:
        Counters: "images" (ROI crops written), "boxes" (shoe boxes kept),
        and per-mode crop counts "single", "pair", "person".
    """
    source_dir = source_spec.dataset_dir
    image_dir = source_dir / "images" / split
    label_dir = source_dir / "labels" / split
    if not image_dir.exists() or not label_dir.exists():
        # Missing split in this source: report empty stats rather than fail.
        return {"images": 0, "boxes": 0, "single": 0, "pair": 0, "person": 0}
    stats = {"images": 0, "boxes": 0, "single": 0, "pair": 0, "person": 0}
    # Prefix output stems with the source name so mixed datasets cannot clash.
    prefix = source_dir.name.replace("-", "_")
    for label_path in sorted(label_dir.glob("*.txt")):
        image_path = find_image(image_dir, label_path.stem)
        if image_path is None:
            continue
        with Image.open(image_path) as image:
            # convert() returns a new RGB image; `image` is rebound to it.
            image = image.convert("RGB")
            width, height = image.size
            annotations = load_annotations(
                label_path,
                width,
                height,
                source_spec.person_ids | source_spec.shoe_ids,
            )
            if not annotations:
                continue
            shoe_boxes = [box for class_id, box in annotations if class_id in source_spec.shoe_ids]
            person_boxes = [box for class_id, box in annotations if class_id in source_spec.person_ids]
            if not shoe_boxes:
                continue
            if source_spec.uses_person_boxes and person_boxes:
                samples = make_person_roi_samples(person_boxes, shoe_boxes, width, height)
            else:
                # No usable person labels: fall back to shoe-based ROI estimation.
                samples = make_roi_samples(shoe_boxes, width, height)
            for sample_idx, sample in enumerate(samples):
                # Empty members means the sample covers all shoe boxes.
                candidate_boxes = [shoe_boxes[idx] for idx in sample.members] if sample.members else shoe_boxes
                roi_boxes = boxes_in_roi(candidate_boxes, sample.roi)
                if not roi_boxes:
                    continue
                out_stem = f"{prefix}_{label_path.stem}_{sample.mode}_{sample_idx:02d}"
                dst_image = output_dir / "images" / split / f"{out_stem}.jpg"
                dst_label = output_dir / "labels" / split / f"{out_stem}.txt"
                crop = image.crop((sample.roi.x1, sample.roi.y1, sample.roi.x2, sample.roi.y2))
                crop.save(dst_image, quality=95)
                yolo_lines = to_yolo_lines(roi_boxes, sample.roi.w, sample.roi.h)
                dst_label.write_text("\n".join(yolo_lines) + "\n", encoding="utf-8")
                stats["images"] += 1
                stats["boxes"] += len(roi_boxes)
                stats[sample.mode] += 1
    return stats
def main() -> None:
    """CLI entry point: build the mixed ROI dataset and print per-source stats."""
    args = parse_args()
    output_dir = Path(args.output)
    if args.clean and output_dir.exists():
        shutil.rmtree(output_dir)
    ensure_output_layout(output_dir)
    summary: dict[str, dict[str, dict[str, int]]] = defaultdict(dict)
    totals = {"images": 0, "boxes": 0, "single": 0, "pair": 0, "person": 0}
    for source in args.sources:
        source_dir = Path(source)
        if not source_dir.exists():
            raise FileNotFoundError(f"Source dataset not found: {source_dir}")
        source_spec = resolve_source_spec(source_dir)
        for split in ("train", "val", "test"):
            stats = build_split(source_spec, output_dir, split)
            summary[source_dir.name][split] = stats
            for key in totals:
                totals[key] += stats.get(key, 0)
    write_yaml(output_dir, args.sources)
    print(f"Output dataset: {output_dir.resolve()}")
    for source_name, split_map in summary.items():
        print(f"[{source_name}]")
        for split in ("train", "val", "test"):
            stats = split_map.get(split, {"images": 0, "boxes": 0, "single": 0, "pair": 0, "person": 0})
            print(
                f" {split}: rois={stats['images']} boxes={stats['boxes']} "
                f"person={stats['person']} single={stats['single']} pair={stats['pair']}"
            )
    # Bug fix: the grand total previously omitted the person-mode count even
    # though it is accumulated in `totals` and shown for every split above.
    print(
        "Total:"
        f" rois={totals['images']} boxes={totals['boxes']}"
        f" person={totals['person']} single={totals['single']} pair={totals['pair']}"
    )
# Script entry point: only run the builder when executed directly.
if __name__ == "__main__":
    main()