# Source: DetectionModelTraining/08_train_compare_models.py
#!/usr/bin/env python3
"""Train multiple YOLO models on the merged shoe dataset and compare ROI performance."""
from __future__ import annotations
import argparse
import json
import os
import sys
from pathlib import Path
from statistics import mean
REPO_ROOT = Path(__file__).resolve().parent
PYDEPS = REPO_ROOT / ".pydeps"
if PYDEPS.exists():
sys.path.insert(0, str(PYDEPS))
from ultralytics import YOLO # noqa: E402
# Checkpoints trained and compared when the user passes no --models override.
DEFAULT_MODELS = ["yolov8s.pt", "yolo11s.pt", "yolo26s.pt"]


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the training/comparison run.

    Args:
        argv: Optional explicit argument list. ``None`` (the default) parses
            ``sys.argv[1:]``, so existing no-argument call sites behave exactly
            as before; passing a list makes the function testable and callable
            programmatically.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Train and compare YOLO shoe detectors")
    parser.add_argument("--data", default="datasets/shoe-public-mix/data.yaml", help="Training dataset yaml")
    parser.add_argument("--roi-dir", default="datasets/roi-shoes", help="Real-world ROI evaluation directory")
    parser.add_argument("--project", default="runs/shoe_compare", help="Ultralytics project directory")
    parser.add_argument("--models", nargs="+", default=DEFAULT_MODELS, help="Model checkpoints or aliases")
    parser.add_argument("--epochs", type=int, default=40, help="Training epochs for each model")
    parser.add_argument("--imgsz", type=int, default=640, help="Training image size")
    parser.add_argument("--batch", type=int, default=16, help="Training batch size")
    parser.add_argument("--workers", type=int, default=8, help="Data loader workers")
    parser.add_argument("--device", default="0", help="Training device, e.g. 0 or cpu")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--patience", type=int, default=20, help="Early stopping patience")
    return parser.parse_args(argv)
def safe_name(model_name: str) -> str:
    """Return an identifier-safe fragment derived from a checkpoint filename.

    Strips any directory and extension, then maps characters that are awkward
    in run names ("." and "-") to underscores.
    """
    fragment = Path(model_name).stem
    for forbidden in (".", "-"):
        fragment = fragment.replace(forbidden, "_")
    return fragment
def evaluate_roi(model_path: Path, roi_dir: Path, save_dir: Path, device: str) -> dict:
    """Run low-confidence inference on the ROI image folder and summarise hits.

    Loads the checkpoint at ``model_path``, predicts over every image under
    ``roi_dir`` (annotated outputs saved into ``save_dir``), and returns a
    summary dict: total images, number with at least one detection, hit rate,
    mean of the per-image maximum confidences, and a per-image breakdown.
    """
    detector = YOLO(str(model_path))
    predictions = detector.predict(
        source=str(roi_dir),
        conf=0.1,  # deliberately low threshold: we want recall on real-world ROIs
        save=True,
        project=str(save_dir.parent),
        name=save_dir.name,
        exist_ok=True,
        verbose=False,
        device=device,
    )

    per_image = []
    max_confs = []
    for prediction in predictions:
        boxes = prediction.boxes
        count = 0 if boxes is None else len(boxes)
        # Compute the top confidence once; 0.0 stands in for "no detections".
        top_conf = float(boxes.conf.max().item()) if count > 0 else 0.0
        if count > 0:
            max_confs.append(top_conf)
        per_image.append(
            {
                "image": Path(prediction.path).name,
                "detections": count,
                "max_conf": top_conf,
            }
        )

    # Every image with at least one box contributed exactly one entry above.
    hits = len(max_confs)
    total = len(per_image)
    return {
        "roi_total": total,
        "roi_hits": hits,
        "roi_hit_rate": hits / total if total else 0.0,
        "roi_mean_max_conf": mean(max_confs) if max_confs else 0.0,
        "per_image": per_image,
    }
def main() -> None:
    """Train each requested model, evaluate it on the ROI set, and write summaries.

    For every checkpoint in ``--models``: trains on the dataset yaml, runs the
    ROI evaluation with the resulting ``best.pt``, writes a per-run
    ``roi_eval_summary.json`` into the run directory, and refreshes the
    cumulative ``compare_summary.json`` in the project directory.

    Raises:
        FileNotFoundError: if the dataset yaml or ROI directory is missing.
    """
    args = parse_args()
    # Keep Ultralytics' config/cache inside the repo instead of the user's home dir.
    os.environ.setdefault("YOLO_CONFIG_DIR", str(REPO_ROOT / ".ultralytics"))

    data = Path(args.data)
    roi_dir = Path(args.roi_dir)
    project = Path(args.project)
    project.mkdir(parents=True, exist_ok=True)

    # Fail fast before spending time loading/training any model.
    if not data.exists():
        raise FileNotFoundError(f"Dataset yaml not found: {data}")
    if not roi_dir.exists():
        raise FileNotFoundError(f"ROI directory not found: {roi_dir}")

    compare_summary = []
    for model_name in args.models:
        run_name = f"{safe_name(model_name)}_shoe_{args.imgsz}"
        print("=" * 80)
        print(f"Training {model_name} -> {run_name}")
        print("=" * 80)

        model = YOLO(model_name)
        train_results = model.train(
            data=str(data.resolve()),
            epochs=args.epochs,
            imgsz=args.imgsz,
            batch=args.batch,
            workers=args.workers,
            device=args.device,
            seed=args.seed,
            patience=args.patience,
            project=str(project.resolve()),
            name=run_name,
            exist_ok=True,
            pretrained=True,
            cos_lr=True,
            amp=True,
            verbose=True,
        )

        best_path = Path(train_results.save_dir) / "weights" / "best.pt"
        roi_save_dir = Path(train_results.save_dir) / "roi_eval"
        roi_summary = evaluate_roi(best_path, roi_dir, roi_save_dir, args.device)

        record = {
            "model": model_name,
            "run_name": run_name,
            "save_dir": str(Path(train_results.save_dir).resolve()),
            "best_pt": str(best_path.resolve()),
            "metrics": {
                # Validation metrics reported by Ultralytics at end of training;
                # missing keys default to 0.0 rather than raising.
                "map50": float(train_results.results_dict.get("metrics/mAP50(B)", 0.0)),
                "map50_95": float(train_results.results_dict.get("metrics/mAP50-95(B)", 0.0)),
                "precision": float(train_results.results_dict.get("metrics/precision(B)", 0.0)),
                "recall": float(train_results.results_dict.get("metrics/recall(B)", 0.0)),
            },
            "roi_eval": roi_summary,
        }
        compare_summary.append(record)

        # Persist both the per-run record and the cumulative comparison after
        # every model, so partial progress survives an interrupted run.
        summary_path = Path(train_results.save_dir) / "roi_eval_summary.json"
        summary_path.write_text(json.dumps(record, indent=2, ensure_ascii=False), encoding="utf-8")
        compare_json = project / "compare_summary.json"
        compare_json.write_text(json.dumps(compare_summary, indent=2, ensure_ascii=False), encoding="utf-8")
        print(json.dumps(record, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()