#!/usr/bin/env python3
"""Train multiple YOLO models on the merged shoe dataset and compare ROI performance."""
from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path
from statistics import mean

REPO_ROOT = Path(__file__).resolve().parent
PYDEPS = REPO_ROOT / ".pydeps"
if PYDEPS.exists():
    # Prefer vendored dependencies in .pydeps over system site-packages.
    sys.path.insert(0, str(PYDEPS))

from ultralytics import YOLO  # noqa: E402

DEFAULT_MODELS = ["yolov8s.pt", "yolo11s.pt", "yolo26s.pt"]


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train and compare YOLO shoe detectors")
    parser.add_argument("--data", default="datasets/shoe-public-mix/data.yaml", help="Training dataset yaml")
    parser.add_argument("--roi-dir", default="datasets/roi-shoes", help="Real-world ROI evaluation directory")
    parser.add_argument("--project", default="runs/shoe_compare", help="Ultralytics project directory")
    parser.add_argument("--models", nargs="+", default=DEFAULT_MODELS, help="Model checkpoints or aliases")
    parser.add_argument("--epochs", type=int, default=40, help="Training epochs for each model")
    parser.add_argument("--imgsz", type=int, default=640, help="Training image size")
    parser.add_argument("--batch", type=int, default=16, help="Training batch size")
    parser.add_argument("--workers", type=int, default=8, help="Data loader workers")
    parser.add_argument("--device", default="0", help="Training device, e.g. 0 or cpu")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--patience", type=int, default=20, help="Early stopping patience")
    return parser.parse_args()


def safe_name(model_name: str) -> str:
    """Turn a checkpoint name such as "yolov8s.pt" into a filesystem-safe run prefix."""
    stem = Path(model_name).stem
    return stem.replace(".", "_").replace("-", "_")


def evaluate_roi(model_path: Path, roi_dir: Path, save_dir: Path, device: str) -> dict:
    """Run low-confidence inference over the ROI images and summarise per-image hits."""
    model = YOLO(str(model_path))
    results = model.predict(
        source=str(roi_dir),
        conf=0.1,
        save=True,
        project=str(save_dir.parent),
        name=save_dir.name,
        exist_ok=True,
        verbose=False,
        device=device,
    )
    per_image = []
    max_confs = []
    hit_count = 0
    for result in results:
        n = 0 if result.boxes is None else len(result.boxes)
        hit = n > 0
        # Compute the per-image max confidence once; reuse it for both aggregates.
        max_conf = float(result.boxes.conf.max().item()) if hit else 0.0
        if hit:
            hit_count += 1
            max_confs.append(max_conf)
        per_image.append(
            {
                "image": Path(result.path).name,
                "detections": n,
                "max_conf": max_conf,
            }
        )
    summary = {
        "roi_total": len(per_image),
        "roi_hits": hit_count,
        "roi_hit_rate": hit_count / len(per_image) if per_image else 0.0,
        "roi_mean_max_conf": mean(max_confs) if max_confs else 0.0,
        "per_image": per_image,
    }
    return summary
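
# For reference, the summary dict returned by evaluate_roi() has the shape sketched
# below. The numbers and image name are invented placeholders documenting the schema,
# not measured results:
#
#   {
#     "roi_total": 25,              # images found under --roi-dir
#     "roi_hits": 21,               # images with >= 1 detection at conf >= 0.1
#     "roi_hit_rate": 0.84,         # roi_hits / roi_total
#     "roi_mean_max_conf": 0.67,    # mean per-image max confidence over hit images
#     "per_image": [
#       {"image": "frame_001.jpg", "detections": 2, "max_conf": 0.81},
#       ...
#     ],
#   }
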
def main() -> None:
    args = parse_args()
    # Keep the Ultralytics config directory inside the repo rather than the user home.
    os.environ.setdefault("YOLO_CONFIG_DIR", str(REPO_ROOT / ".ultralytics"))
    data = Path(args.data)
    roi_dir = Path(args.roi_dir)
    project = Path(args.project)
    project.mkdir(parents=True, exist_ok=True)
    if not data.exists():
        raise FileNotFoundError(f"Dataset yaml not found: {data}")
    if not roi_dir.exists():
        raise FileNotFoundError(f"ROI directory not found: {roi_dir}")

    compare_summary = []
    for model_name in args.models:
        run_name = f"{safe_name(model_name)}_shoe_{args.imgsz}"
        print("=" * 80)
        print(f"Training {model_name} -> {run_name}")
        print("=" * 80)
        model = YOLO(model_name)
        train_results = model.train(
            data=str(data.resolve()),
            epochs=args.epochs,
            imgsz=args.imgsz,
            batch=args.batch,
            workers=args.workers,
            device=args.device,
            seed=args.seed,
            patience=args.patience,
            project=str(project.resolve()),
            name=run_name,
            exist_ok=True,
            pretrained=True,
            cos_lr=True,
            amp=True,
            verbose=True,
        )
        best_path = Path(train_results.save_dir) / "weights" / "best.pt"
        roi_save_dir = Path(train_results.save_dir) / "roi_eval"
        roi_summary = evaluate_roi(best_path, roi_dir, roi_save_dir, args.device)
        record = {
            "model": model_name,
            "run_name": run_name,
            "save_dir": str(Path(train_results.save_dir).resolve()),
            "best_pt": str(best_path.resolve()),
            "metrics": {
                "map50": float(train_results.results_dict.get("metrics/mAP50(B)", 0.0)),
                "map50_95": float(train_results.results_dict.get("metrics/mAP50-95(B)", 0.0)),
                "precision": float(train_results.results_dict.get("metrics/precision(B)", 0.0)),
                "recall": float(train_results.results_dict.get("metrics/recall(B)", 0.0)),
            },
            "roi_eval": roi_summary,
        }
        compare_summary.append(record)
        summary_path = Path(train_results.save_dir) / "roi_eval_summary.json"
        summary_path.write_text(json.dumps(record, indent=2, ensure_ascii=False), encoding="utf-8")
        # Rewrite the aggregate comparison after each model so partial results survive crashes.
        compare_json = project / "compare_summary.json"
        compare_json.write_text(json.dumps(compare_summary, indent=2, ensure_ascii=False), encoding="utf-8")
        print(json.dumps(record, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()
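
# Example invocation. The script filename below is a placeholder for wherever this
# file lives in the repo; the flags mirror parse_args() above:
#
#   python compare_shoe_models.py \
#       --data datasets/shoe-public-mix/data.yaml \
#       --roi-dir datasets/roi-shoes \
#       --models yolov8s.pt yolo11s.pt \
#       --epochs 40 --imgsz 640 --batch 16 --device 0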