146 lines
4.3 KiB
Python
146 lines
4.3 KiB
Python
#!/usr/bin/env python3
|
|
"""Build a merged single-class public shoe dataset for YOLO training."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import shutil
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
|
|
# Dataset roots merged when --sources is not supplied on the command line.
# Each root is expected to contain images/<split> and labels/<split> folders.
DEFAULT_SOURCES = ["datasets/openimages-shoes-yolo", "datasets/ppe-shoes"]
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the dataset merge.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing ``parse_args()`` callers behave
            exactly as before; passing a list lets tests or other scripts drive
            the CLI without touching ``sys.argv``.

    Returns:
        Namespace with ``sources`` (list of dataset directories), ``output``
        (output directory) and ``clean`` (bool flag).
    """
    parser = argparse.ArgumentParser(description="Merge public shoe datasets into one YOLO dataset")
    parser.add_argument(
        "--sources",
        nargs="+",
        # Hand out a copy so the module-level default list is never aliased
        # into the returned namespace (and cannot be mutated through it).
        default=list(DEFAULT_SOURCES),
        help="Source dataset directories containing images/<split> and labels/<split>",
    )
    parser.add_argument(
        "--output",
        default="datasets/shoe-public-mix",
        help="Output merged dataset directory",
    )
    parser.add_argument(
        "--clean",
        action="store_true",
        help="Delete the output directory before rebuilding",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def ensure_output_layout(output_dir: Path) -> None:
    """Create ``images/<split>`` and ``labels/<split>`` under *output_dir*.

    Covers the train, val and test splits. Missing parents are created and
    directories that already exist are left as they are.
    """
    split_names = ("train", "val", "test")
    for split_name in split_names:
        for subtree in ("images", "labels"):
            target = output_dir.joinpath(subtree, split_name)
            target.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
def _index_images(image_dir: Path, image_exts: tuple[str, ...]) -> dict[str, Path]:
    """Map each image stem in *image_dir* to one file, preferring earlier extensions.

    Extensions are compared case-insensitively, so ``IMG_1.JPG`` counts as a
    ``.jpg`` image. When several files share a stem, the one whose extension
    appears first in *image_exts* wins.
    """
    priority: dict[str, int] = {}
    for rank, ext in enumerate(image_exts):
        priority.setdefault(ext.lower(), rank)

    chosen: dict[str, tuple[int, Path]] = {}
    # Sorted iteration makes the pick deterministic when two files tie on
    # rank (e.g. ``a.jpg`` and ``a.JPG`` on a case-sensitive filesystem).
    for path in sorted(image_dir.iterdir()):
        rank = priority.get(path.suffix.lower())
        if rank is None or not path.is_file():
            continue
        current = chosen.get(path.stem)
        if current is None or rank < current[0]:
            chosen[path.stem] = (rank, path)
    return {stem: path for stem, (_, path) in chosen.items()}


def copy_split(
    source_dir: Path,
    output_dir: Path,
    split: str,
    *,
    image_exts: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".bmp", ".webp"),
) -> tuple[int, int]:
    """Copy one split of a YOLO-layout dataset into the merged output.

    For every ``labels/<split>/*.txt`` in *source_dir* that has a matching image
    (same stem) in ``images/<split>`` and at least one non-blank line, the image
    and a whitespace-normalised copy of the label are written to
    ``<output_dir>/{images,labels}/<split>/`` as ``<prefix>_<stem>``. The prefix
    is ``source_dir.name`` with ``-`` replaced by ``_``. The image suffix is
    lower-cased. Existing files at the destination are overwritten.

    Labels without an image, and empty label files, are skipped.

    NOTE(review): class ids are copied verbatim. The merged dataset is
    described as single-class, so this assumes every source already labels
    shoes as class 0. Confirm per source.

    Args:
        source_dir: Dataset root containing ``images/<split>`` and ``labels/<split>``.
        output_dir: Merged dataset root; its ``images/<split>`` and
            ``labels/<split>`` directories must already exist.
        split: Split name, e.g. ``"train"``.
        image_exts: Accepted image extensions in priority order (matched
            case-insensitively).

    Returns:
        ``(images_copied, boxes_copied)``, where boxes are the non-blank label
        lines written. ``(0, 0)`` if either split directory is missing.
    """
    image_dir = source_dir / "images" / split
    label_dir = source_dir / "labels" / split
    if not image_dir.is_dir() or not label_dir.is_dir():
        return 0, 0

    # One directory scan instead of several stat calls per label; it also
    # catches upper-case extensions that a literal ".jpg" lookup misses on
    # case-sensitive filesystems.
    images_by_stem = _index_images(image_dir, image_exts)
    prefix = source_dir.name.replace("-", "_")
    images_copied = 0
    boxes_copied = 0

    for label_file in sorted(label_dir.glob("*.txt")):
        image_file = images_by_stem.get(label_file.stem)
        if image_file is None:
            continue

        lines = [line.strip() for line in label_file.read_text(encoding="utf-8").splitlines() if line.strip()]
        if not lines:
            continue

        out_stem = f"{prefix}_{label_file.stem}"
        dst_image = output_dir / "images" / split / f"{out_stem}{image_file.suffix.lower()}"
        dst_label = output_dir / "labels" / split / f"{out_stem}.txt"

        shutil.copy2(image_file, dst_image)
        dst_label.write_text("\n".join(lines) + "\n", encoding="utf-8")

        images_copied += 1
        boxes_copied += len(lines)

    return images_copied, boxes_copied
|
|
|
|
|
|
def write_yaml(output_dir: Path) -> None:
    """Write the YOLO dataset descriptor ``data.yaml`` into *output_dir*.

    The file sets ``path`` to the resolved absolute output directory. It points
    the train/val/test keys at ``images/<split>``, declares a single class
    ``shoe``, and appends a ``dataset_info`` section. Any existing
    ``data.yaml`` is overwritten.
    """
    import json  # local import: only needed to quote the dataset root safely

    # An unquoted root breaks as soon as it contains " #" (YAML comment start,
    # silently truncating the path) or ": ". json.dumps yields a double-quoted
    # string that is also a valid YAML double-quoted scalar. ensure_ascii=False
    # keeps non-ASCII characters literal instead of \u-escaping them.
    quoted_root = json.dumps(output_dir.resolve().as_posix(), ensure_ascii=False)

    yaml_lines = [
        "# Public shoe training mix",
        "",
        f"path: {quoted_root}",
        "train: images/train",
        "val: images/val",
        "test: images/test",
        "",
        "nc: 1",
        "names: ['shoe']",
        "",
        "dataset_info:",
        "  name: shoe-public-mix",
        "  task: detect_shoe",
        "  note: merged Open Images shoe data and PPE shoe subset",
        "",  # trailing empty entry -> file ends with a newline
    ]
    (output_dir / "data.yaml").write_text("\n".join(yaml_lines), encoding="utf-8")
|
|
|
|
|
|
def main() -> None:
    """Merge every ``--sources`` dataset into ``--output`` and print per-split counts.

    All sources are validated before anything is deleted or copied. A typo in
    ``--sources`` can therefore no longer wipe the output directory (with
    ``--clean``) or leave a half-merged dataset behind.

    Raises:
        FileNotFoundError: A source directory does not exist.
        ValueError: Two sources map to the same output file prefix, so their
            files would overwrite each other.
    """
    args = parse_args()
    output_dir = Path(args.output)
    splits = ("train", "val", "test")

    source_dirs = [Path(source) for source in args.sources]
    seen_prefixes: dict[str, Path] = {}
    for source_dir in source_dirs:
        if not source_dir.exists():
            raise FileNotFoundError(f"Source dataset not found: {source_dir}")
        # Merged files are named "<dir name with '-' -> '_'>_<stem>". Two sources
        # that map to the same prefix would silently overwrite each other's
        # images/labels and share a single summary entry below.
        prefix = source_dir.name.replace("-", "_")
        if prefix in seen_prefixes:
            raise ValueError(
                f"Sources {seen_prefixes[prefix]} and {source_dir} map to the same "
                f"output prefix '{prefix}'; rename one of them"
            )
        seen_prefixes[prefix] = source_dir

    # Only destroy the previous output once every source is known to be usable.
    if args.clean and output_dir.exists():
        shutil.rmtree(output_dir)
    ensure_output_layout(output_dir)

    summary: dict[str, dict[str, tuple[int, int]]] = defaultdict(dict)
    for source_dir in source_dirs:
        for split in splits:
            summary[source_dir.name][split] = copy_split(source_dir, output_dir, split)

    write_yaml(output_dir)

    print(f"Output dataset: {output_dir.resolve()}")
    total_images = 0
    total_boxes = 0
    for source_name, split_map in summary.items():
        print(f"[{source_name}]")
        for split in splits:
            images_copied, boxes_copied = split_map.get(split, (0, 0))
            total_images += images_copied
            total_boxes += boxes_copied
            print(f"  {split}: images={images_copied} boxes={boxes_copied}")

    print(f"Total images: {total_images}")
    print(f"Total boxes: {total_boxes}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the merge only when executed as a script, not when imported.
    main()
|