#!/usr/bin/env python3
|
|
"""Train multiple YOLO models on the merged shoe dataset and compare ROI performance."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
from statistics import mean
|
|
|
|
|
|
REPO_ROOT = Path(__file__).resolve().parent
|
|
PYDEPS = REPO_ROOT / ".pydeps"
|
|
if PYDEPS.exists():
|
|
sys.path.insert(0, str(PYDEPS))
|
|
|
|
from ultralytics import YOLO # noqa: E402
|
|
|
|
|
|
DEFAULT_MODELS = ["yolov8s.pt", "yolo11s.pt", "yolo26s.pt"]


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the train-and-compare run."""
    cli = argparse.ArgumentParser(description="Train and compare YOLO shoe detectors")
    cli.add_argument("--data", default="datasets/shoe-public-mix/data.yaml", help="Training dataset yaml")
    cli.add_argument("--roi-dir", default="datasets/roi-shoes", help="Real-world ROI evaluation directory")
    cli.add_argument("--project", default="runs/shoe_compare", help="Ultralytics project directory")
    cli.add_argument("--models", nargs="+", default=DEFAULT_MODELS, help="Model checkpoints or aliases")
    # The integer-valued training knobs share the same shape; register them in one pass.
    for flag, value, text in (
        ("--epochs", 40, "Training epochs for each model"),
        ("--imgsz", 640, "Training image size"),
        ("--batch", 16, "Training batch size"),
        ("--workers", 8, "Data loader workers"),
    ):
        cli.add_argument(flag, type=int, default=value, help=text)
    cli.add_argument("--device", default="0", help="Training device, e.g. 0 or cpu")
    cli.add_argument("--seed", type=int, default=42, help="Random seed")
    cli.add_argument("--patience", type=int, default=20, help="Early stopping patience")
    return cli.parse_args()
|
|
|
|
|
|
def safe_name(model_name: str) -> str:
    """Turn a checkpoint path into an identifier usable in a run name."""
    # Drop any directory and the file extension, then normalise the
    # remaining separators ('.' and '-') to underscores.
    base = Path(model_name).stem
    for separator in (".", "-"):
        base = base.replace(separator, "_")
    return base
|
|
|
|
|
|
def evaluate_roi(model_path: Path, roi_dir: Path, save_dir: Path, device: str) -> dict:
    """Run a trained checkpoint over a directory of real-world ROI images.

    Args:
        model_path: Path to the trained weights (e.g. ``best.pt``).
        roi_dir: Directory of evaluation images fed to ``model.predict``.
        save_dir: Where annotated predictions are saved (``project``/``name``
            are derived from its parent and basename).
        device: Ultralytics device string, e.g. ``"0"`` or ``"cpu"``.

    Returns:
        Summary dict with total image count, hit count, hit rate, the mean
        of per-image top confidences over hit images, and per-image details.
    """
    model = YOLO(str(model_path))
    results = model.predict(
        source=str(roi_dir),
        conf=0.1,  # deliberately low threshold for this recall-oriented check
        save=True,
        project=str(save_dir.parent),
        name=save_dir.name,
        exist_ok=True,
        verbose=False,
        device=device,
    )

    per_image = []
    max_confs = []
    hit_count = 0
    for result in results:
        n = 0 if result.boxes is None else len(result.boxes)
        # Compute the top confidence once per image (the original evaluated
        # result.boxes.conf.max().item() twice for every hit image).
        max_conf = float(result.boxes.conf.max().item()) if n > 0 else 0.0
        if n > 0:
            hit_count += 1
            max_confs.append(max_conf)
        per_image.append(
            {
                "image": Path(result.path).name,
                "detections": n,
                "max_conf": max_conf,
            }
        )

    summary = {
        "roi_total": len(per_image),
        "roi_hits": hit_count,
        # Guard against an empty ROI directory.
        "roi_hit_rate": hit_count / len(per_image) if per_image else 0.0,
        "roi_mean_max_conf": mean(max_confs) if max_confs else 0.0,
        "per_image": per_image,
    }
    return summary
|
|
|
|
|
|
def main() -> None:
    """Train each requested model on the shoe dataset, evaluate its best
    checkpoint on real-world ROI images, and write per-run plus aggregate
    comparison summaries as JSON."""
    args = parse_args()
    # Keep Ultralytics' config directory inside the repo rather than $HOME.
    os.environ.setdefault("YOLO_CONFIG_DIR", str(REPO_ROOT / ".ultralytics"))

    data = Path(args.data)
    roi_dir = Path(args.roi_dir)
    project = Path(args.project)
    project.mkdir(parents=True, exist_ok=True)

    # Fail fast on missing inputs before any (long) training starts.
    if not data.exists():
        raise FileNotFoundError(f"Dataset yaml not found: {data}")
    if not roi_dir.exists():
        raise FileNotFoundError(f"ROI directory not found: {roi_dir}")

    compare_summary = []  # one record per trained model
    for model_name in args.models:
        run_name = f"{safe_name(model_name)}_shoe_{args.imgsz}"
        print("=" * 80)
        print(f"Training {model_name} -> {run_name}")
        print("=" * 80)

        model = YOLO(model_name)
        train_results = model.train(
            data=str(data.resolve()),
            epochs=args.epochs,
            imgsz=args.imgsz,
            batch=args.batch,
            workers=args.workers,
            device=args.device,
            seed=args.seed,
            patience=args.patience,
            project=str(project.resolve()),
            name=run_name,
            exist_ok=True,
            pretrained=True,
            cos_lr=True,  # cosine learning-rate schedule
            amp=True,  # mixed-precision training
            verbose=True,
        )

        # Score the best checkpoint from this run on the ROI images.
        best_path = Path(train_results.save_dir) / "weights" / "best.pt"
        roi_save_dir = Path(train_results.save_dir) / "roi_eval"
        roi_summary = evaluate_roi(best_path, roi_dir, roi_save_dir, args.device)

        record = {
            "model": model_name,
            "run_name": run_name,
            "save_dir": str(Path(train_results.save_dir).resolve()),
            "best_pt": str(best_path.resolve()),
            "metrics": {
                # Validation metrics as reported by Ultralytics; missing
                # keys default to 0.0.
                "map50": float(train_results.results_dict.get("metrics/mAP50(B)", 0.0)),
                "map50_95": float(train_results.results_dict.get("metrics/mAP50-95(B)", 0.0)),
                "precision": float(train_results.results_dict.get("metrics/precision(B)", 0.0)),
                "recall": float(train_results.results_dict.get("metrics/recall(B)", 0.0)),
            },
            "roi_eval": roi_summary,
        }
        compare_summary.append(record)

        # Per-run ROI summary stored next to the training artifacts.
        summary_path = Path(train_results.save_dir) / "roi_eval_summary.json"
        summary_path.write_text(json.dumps(record, indent=2, ensure_ascii=False), encoding="utf-8")

        # NOTE(review): original indentation was lost in extraction; the
        # aggregate file is rewritten after every model here so partial
        # results survive an interrupted run — confirm intended placement.
        compare_json = project / "compare_summary.json"
        compare_json.write_text(json.dumps(compare_summary, indent=2, ensure_ascii=False), encoding="utf-8")

        print(json.dumps(record, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()
|