#!/usr/bin/env python3
"""Build a merged single-class public shoe dataset for YOLO training."""
from __future__ import annotations
import argparse
import shutil
from collections import defaultdict
from pathlib import Path
# Source datasets merged by default when --sources is not given.
DEFAULT_SOURCES = [
    "datasets/openimages-shoes-yolo",
    "datasets/ppe-shoes",
]


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the dataset-merge script.

    Returns:
        Namespace with ``sources`` (list of dataset dirs), ``output``
        (merged dataset dir) and ``clean`` (rebuild-from-scratch flag).
    """
    parser = argparse.ArgumentParser(description="Merge public shoe datasets into one YOLO dataset")
    parser.add_argument(
        "--sources",
        nargs="+",
        default=DEFAULT_SOURCES,
        help="Source dataset directories containing images/<split> and labels/<split>",
    )
    parser.add_argument(
        "--output",
        default="datasets/shoe-public-mix",
        help="Output merged dataset directory",
    )
    parser.add_argument(
        "--clean",
        action="store_true",
        help="Delete the output directory before rebuilding",
    )
    return parser.parse_args()
def ensure_output_layout(output_dir: Path) -> None:
    """Create the images/<split> and labels/<split> tree under *output_dir*.

    Idempotent: existing directories are left untouched.
    """
    for kind in ("images", "labels"):
        for split in ("train", "val", "test"):
            (output_dir / kind / split).mkdir(parents=True, exist_ok=True)
def copy_split(source_dir: Path, output_dir: Path, split: str) -> tuple[int, int]:
    """Copy one split's labelled images from a source dataset into the merged output.

    Images are matched to YOLO label files by stem. The extension match is
    case-insensitive (both ``a.jpg`` and ``a.JPG`` are found) — the original
    lookup only tried lowercase extensions even though the destination name
    already lowercased the suffix. Copied files are prefixed with the source
    directory name (dashes mapped to underscores) to avoid stem collisions
    between datasets. Empty label files and labels without a matching image
    are skipped.

    Args:
        source_dir: Source dataset root containing images/<split> and labels/<split>.
        output_dir: Merged dataset root; its layout must already exist.
        split: One of "train", "val", "test".

    Returns:
        ``(images_copied, boxes_copied)`` for this source/split, or ``(0, 0)``
        when the split directories are missing.
    """
    image_dir = source_dir / "images" / split
    label_dir = source_dir / "labels" / split
    if not image_dir.exists() or not label_dir.exists():
        return 0, 0
    images_copied = 0
    boxes_copied = 0
    # Prefix keeps merged stems unique across source datasets.
    prefix = source_dir.name.replace("-", "_")
    for label_file in sorted(label_dir.glob("*.txt")):
        image_file = None
        for ext in (".jpg", ".jpeg", ".png", ".bmp", ".webp"):
            # Accept both lowercase and uppercase extensions on disk.
            for variant in (ext, ext.upper()):
                candidate = image_dir / f"{label_file.stem}{variant}"
                if candidate.exists():
                    image_file = candidate
                    break
            if image_file is not None:
                break
        if image_file is None:
            continue  # orphan label: no matching image
        lines = [line.strip() for line in label_file.read_text(encoding="utf-8").splitlines() if line.strip()]
        if not lines:
            continue  # empty label file: nothing to train on
        out_stem = f"{prefix}_{label_file.stem}"
        dst_image = output_dir / "images" / split / f"{out_stem}{image_file.suffix.lower()}"
        dst_label = output_dir / "labels" / split / f"{out_stem}.txt"
        shutil.copy2(image_file, dst_image)
        # Rewrite the label normalized: stripped lines, single trailing newline.
        dst_label.write_text("\n".join(lines) + "\n", encoding="utf-8")
        images_copied += 1
        boxes_copied += len(lines)
    return images_copied, boxes_copied
def write_yaml(output_dir: Path) -> None:
    """Write the YOLO ``data.yaml`` describing the merged dataset.

    The file pins an absolute ``path``, the three split subdirectories,
    the single ``shoe`` class, and a small provenance note.
    """
    lines = [
        "# Public shoe training mix",
        "",
        f"path: {output_dir.resolve().as_posix()}",
        "train: images/train",
        "val: images/val",
        "test: images/test",
        "",
        "nc: 1",
        "names: ['shoe']",
        "",
        "dataset_info:",
        " name: shoe-public-mix",
        " task: detect_shoe",
        " note: merged Open Images shoe data and PPE shoe subset",
        "",
    ]
    (output_dir / "data.yaml").write_text("\n".join(lines), encoding="utf-8")
def main() -> None:
    """Entry point: merge every source dataset into one output and print a summary."""
    args = parse_args()
    output_dir = Path(args.output)
    if args.clean and output_dir.exists():
        shutil.rmtree(output_dir)
    ensure_output_layout(output_dir)

    # source name -> split name -> (images copied, boxes copied)
    summary: dict[str, dict[str, tuple[int, int]]] = defaultdict(dict)
    for source in args.sources:
        source_dir = Path(source)
        if not source_dir.exists():
            raise FileNotFoundError(f"Source dataset not found: {source_dir}")
        for split in ("train", "val", "test"):
            summary[source_dir.name][split] = copy_split(source_dir, output_dir, split)

    write_yaml(output_dir)
    print(f"Output dataset: {output_dir.resolve()}")

    grand_images = 0
    grand_boxes = 0
    for source_name, per_split in summary.items():
        print(f"[{source_name}]")
        for split in ("train", "val", "test"):
            n_images, n_boxes = per_split.get(split, (0, 0))
            grand_images += n_images
            grand_boxes += n_boxes
            print(f" {split}: images={n_images} boxes={n_boxes}")
    print(f"Total images: {grand_images}")
    print(f"Total boxes: {grand_boxes}")


if __name__ == "__main__":
    main()