Add RKNN ROI shoe evaluation script
This commit is contained in:
parent
58bb66e66e
commit
1dc27e6ddd
346
scripts/eval_rknn_roi_shoes.py
Normal file
346
scripts/eval_rknn_roi_shoes.py
Normal file
@ -0,0 +1,346 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Evaluate RKNN shoe detector recall on positive ROI crops.
|
||||
|
||||
This script is intended to run on RK3588.
|
||||
|
||||
Usage:
|
||||
python3 scripts/eval_rknn_roi_shoes.py
|
||||
python3 scripts/eval_rknn_roi_shoes.py --model models/yolov8s_shoe_640-rk3588.rknn
|
||||
python3 scripts/eval_rknn_roi_shoes.py --conf-list 0.15,0.22,0.25,0.35 --save-dir train/roi-shoes-vis
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from rknnlite.api import RKNNLite # type: ignore
|
||||
_BACKEND = "rknnlite"
|
||||
except Exception:
|
||||
RKNNLite = None
|
||||
try:
|
||||
from rknn.api import RKNN # type: ignore
|
||||
_BACKEND = "rknn"
|
||||
except Exception:
|
||||
RKNN = None
|
||||
_BACKEND = ""
|
||||
|
||||
|
||||
DEFAULT_MODEL = "models/yolov8s_shoe_640-rk3588.rknn"
|
||||
DEFAULT_CONF_LIST = [0.15, 0.22, 0.25, 0.35]
|
||||
DEFAULT_IMG_DIRS = ["train/roi-shoes", "train/roi_shoes"]
|
||||
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Det:
|
||||
x: float
|
||||
y: float
|
||||
w: float
|
||||
h: float
|
||||
conf: float
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description="Evaluate RKNN shoe recall on ROI positives")
|
||||
p.add_argument("--model", default=DEFAULT_MODEL, help="Path to .rknn model")
|
||||
p.add_argument("--img-dir", default="", help="Directory containing ROI shoe images")
|
||||
p.add_argument("--imgsz", type=int, default=640, help="Model input size")
|
||||
p.add_argument("--box-format", default="cxcywh", choices=["cxcywh", "xywh", "xyxy"])
|
||||
p.add_argument("--nms", type=float, default=0.45, help="NMS IoU threshold")
|
||||
p.add_argument(
|
||||
"--conf-list",
|
||||
default=",".join(str(x) for x in DEFAULT_CONF_LIST),
|
||||
help="Comma-separated confidence thresholds",
|
||||
)
|
||||
p.add_argument("--save-dir", default="", help="Optional directory to save visualized detections")
|
||||
p.add_argument("--save-threshold", type=float, default=0.22, help="Confidence used for visualization saving")
|
||||
p.add_argument("--limit", type=int, default=0, help="Only evaluate the first N images")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def resolve_img_dir(user_value: str) -> Path:
|
||||
if user_value:
|
||||
p = Path(user_value)
|
||||
if not p.is_dir():
|
||||
raise FileNotFoundError(f"image directory not found: {p}")
|
||||
return p
|
||||
|
||||
for item in DEFAULT_IMG_DIRS:
|
||||
p = Path(item)
|
||||
if p.is_dir():
|
||||
return p
|
||||
raise FileNotFoundError(
|
||||
"image directory not found. tried: " + ", ".join(DEFAULT_IMG_DIRS)
|
||||
)
|
||||
|
||||
|
||||
def iter_images(img_dir: Path) -> list[Path]:
|
||||
files = [p for p in sorted(img_dir.iterdir()) if p.is_file() and p.suffix.lower() in IMG_EXTS]
|
||||
return files
|
||||
|
||||
|
||||
def load_backend(model_path: str):
|
||||
if _BACKEND == "rknnlite":
|
||||
inst = RKNNLite()
|
||||
ret = inst.load_rknn(model_path)
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"load_rknn failed: ret={ret}")
|
||||
core_auto = getattr(RKNNLite, "NPU_CORE_AUTO", 0)
|
||||
ret = inst.init_runtime(core_mask=core_auto)
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"init_runtime failed: ret={ret}")
|
||||
return inst
|
||||
|
||||
if _BACKEND == "rknn":
|
||||
inst = RKNN(verbose=False)
|
||||
ret = inst.load_rknn(model_path)
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"load_rknn failed: ret={ret}")
|
||||
ret = inst.init_runtime()
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"init_runtime failed: ret={ret}")
|
||||
return inst
|
||||
|
||||
raise RuntimeError("Neither rknnlite.api nor rknn.api is available")
|
||||
|
||||
|
||||
def release_backend(inst) -> None:
|
||||
try:
|
||||
inst.release()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def preprocess_bgr(img_bgr: np.ndarray, imgsz: int) -> np.ndarray:
|
||||
resized = cv2.resize(img_bgr, (imgsz, imgsz), interpolation=cv2.INTER_LINEAR)
|
||||
rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
|
||||
return np.expand_dims(rgb, axis=0)
|
||||
|
||||
|
||||
def infer(inst, inp: np.ndarray):
|
||||
outputs = inst.inference(inputs=[inp], data_format=["nhwc"])
|
||||
if not outputs:
|
||||
raise RuntimeError("empty inference outputs")
|
||||
return outputs
|
||||
|
||||
|
||||
def to_2d_output(output: np.ndarray) -> np.ndarray:
|
||||
arr = np.asarray(output)
|
||||
if arr.ndim == 3 and arr.shape[0] == 1:
|
||||
arr = arr[0]
|
||||
if arr.ndim != 2:
|
||||
raise RuntimeError(f"unexpected output shape: {tuple(np.asarray(output).shape)}")
|
||||
return arr
|
||||
|
||||
|
||||
def decode_boxes(raw: np.ndarray, conf_thresh: float, box_format: str, imgsz: int) -> list[Det]:
|
||||
arr = to_2d_output(raw)
|
||||
|
||||
if arr.shape[0] == 5:
|
||||
arr = arr.T # [8400, 5]
|
||||
elif arr.shape[1] == 5:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError(f"unsupported YOLO output shape: {tuple(arr.shape)}")
|
||||
|
||||
dets: list[Det] = []
|
||||
for row in arr:
|
||||
a, b, c, d, score = [float(x) for x in row[:5]]
|
||||
conf = 1.0 / (1.0 + math.exp(-score)) if (score < -0.1 or score > 1.5) else score
|
||||
if conf < conf_thresh:
|
||||
continue
|
||||
|
||||
if max(abs(a), abs(b), abs(c), abs(d)) <= 2.5:
|
||||
# normalized
|
||||
a *= imgsz
|
||||
b *= imgsz
|
||||
c *= imgsz
|
||||
d *= imgsz
|
||||
|
||||
if box_format == "cxcywh":
|
||||
x = a - c / 2.0
|
||||
y = b - d / 2.0
|
||||
w = c
|
||||
h = d
|
||||
elif box_format == "xywh":
|
||||
x = a
|
||||
y = b
|
||||
w = c
|
||||
h = d
|
||||
else:
|
||||
x = a
|
||||
y = b
|
||||
w = c - a
|
||||
h = d - b
|
||||
|
||||
if w <= 1e-3 or h <= 1e-3:
|
||||
continue
|
||||
if w > imgsz * 1.2 or h > imgsz * 1.2:
|
||||
continue
|
||||
dets.append(Det(x=x, y=y, w=w, h=h, conf=conf))
|
||||
return dets
|
||||
|
||||
|
||||
def iou(a: Det, b: Det) -> float:
|
||||
x1 = max(a.x, b.x)
|
||||
y1 = max(a.y, b.y)
|
||||
x2 = min(a.x + a.w, b.x + b.w)
|
||||
y2 = min(a.y + a.h, b.y + b.h)
|
||||
if x2 <= x1 or y2 <= y1:
|
||||
return 0.0
|
||||
inter = (x2 - x1) * (y2 - y1)
|
||||
area_a = a.w * a.h
|
||||
area_b = b.w * b.h
|
||||
union = area_a + area_b - inter
|
||||
return inter / union if union > 0 else 0.0
|
||||
|
||||
|
||||
def apply_nms(dets: Iterable[Det], nms_thresh: float) -> list[Det]:
|
||||
items = sorted(dets, key=lambda x: x.conf, reverse=True)
|
||||
keep: list[Det] = []
|
||||
for det in items:
|
||||
if any(iou(det, kept) > nms_thresh for kept in keep):
|
||||
continue
|
||||
keep.append(det)
|
||||
return keep
|
||||
|
||||
|
||||
def scale_back(det: Det, src_w: int, src_h: int, imgsz: int) -> Det:
|
||||
sx = src_w / float(imgsz)
|
||||
sy = src_h / float(imgsz)
|
||||
return Det(
|
||||
x=det.x * sx,
|
||||
y=det.y * sy,
|
||||
w=det.w * sx,
|
||||
h=det.h * sy,
|
||||
conf=det.conf,
|
||||
)
|
||||
|
||||
|
||||
def draw_and_save(img_bgr: np.ndarray, dets: list[Det], out_path: Path) -> None:
|
||||
vis = img_bgr.copy()
|
||||
for det in dets:
|
||||
x1 = max(0, int(round(det.x)))
|
||||
y1 = max(0, int(round(det.y)))
|
||||
x2 = max(x1 + 1, int(round(det.x + det.w)))
|
||||
y2 = max(y1 + 1, int(round(det.y + det.h)))
|
||||
cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
cv2.putText(
|
||||
vis,
|
||||
f"{det.conf:.2f}",
|
||||
(x1, max(12, y1 - 4)),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.45,
|
||||
(0, 255, 0),
|
||||
1,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cv2.imwrite(str(out_path), vis)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
img_dir = resolve_img_dir(args.img_dir)
|
||||
model_path = Path(args.model)
|
||||
if not model_path.is_file():
|
||||
raise FileNotFoundError(f"model not found: {model_path}")
|
||||
|
||||
conf_list = [float(x.strip()) for x in args.conf_list.split(",") if x.strip()]
|
||||
if not conf_list:
|
||||
raise ValueError("conf-list is empty")
|
||||
|
||||
image_paths = iter_images(img_dir)
|
||||
if args.limit > 0:
|
||||
image_paths = image_paths[: args.limit]
|
||||
if not image_paths:
|
||||
raise RuntimeError(f"no images found in {img_dir}")
|
||||
|
||||
inst = load_backend(str(model_path))
|
||||
try:
|
||||
per_image_best: list[dict] = []
|
||||
save_dir = Path(args.save_dir) if args.save_dir else None
|
||||
|
||||
for idx, img_path in enumerate(image_paths, start=1):
|
||||
img_bgr = cv2.imread(str(img_path))
|
||||
if img_bgr is None:
|
||||
print(f"[WARN] failed to read image: {img_path}")
|
||||
continue
|
||||
|
||||
inp = preprocess_bgr(img_bgr, args.imgsz)
|
||||
outputs = infer(inst, inp)
|
||||
raw_dets = decode_boxes(outputs[0], conf_thresh=0.0, box_format=args.box_format, imgsz=args.imgsz)
|
||||
raw_dets = apply_nms(raw_dets, args.nms)
|
||||
scaled = [scale_back(d, img_bgr.shape[1], img_bgr.shape[0], args.imgsz) for d in raw_dets]
|
||||
best_conf = max((d.conf for d in scaled), default=0.0)
|
||||
per_image_best.append(
|
||||
{
|
||||
"file": img_path.name,
|
||||
"best_conf": round(best_conf, 4),
|
||||
"num_boxes": len(scaled),
|
||||
}
|
||||
)
|
||||
|
||||
if save_dir is not None:
|
||||
dets_for_vis = [d for d in scaled if d.conf >= args.save_threshold]
|
||||
draw_and_save(img_bgr, dets_for_vis, save_dir / img_path.name)
|
||||
|
||||
print(f"[{idx:02d}/{len(image_paths)}] {img_path.name}: boxes={len(scaled)} best_conf={best_conf:.4f}")
|
||||
|
||||
total = len(per_image_best)
|
||||
print("\n=== Summary ===")
|
||||
print(f"model: {model_path}")
|
||||
print(f"images: {img_dir} ({total})")
|
||||
print(f"nms: {args.nms:.2f}")
|
||||
print("")
|
||||
print(f"{'conf':>6} {'hits':>8} {'hit_rate':>10} {'mean_max_conf':>14}")
|
||||
results = []
|
||||
best_values = np.array([x["best_conf"] for x in per_image_best], dtype=np.float32)
|
||||
for conf in conf_list:
|
||||
hits = int(np.sum(best_values >= conf))
|
||||
hit_rate = hits / total if total else 0.0
|
||||
mean_max_conf = float(np.mean(best_values)) if total else 0.0
|
||||
print(f"{conf:>6.2f} {hits:>3d}/{total:<4d} {hit_rate:>9.4f} {mean_max_conf:>14.4f}")
|
||||
missed = [x["file"] for x in per_image_best if x["best_conf"] < conf]
|
||||
results.append(
|
||||
{
|
||||
"conf": conf,
|
||||
"hits": hits,
|
||||
"total": total,
|
||||
"hit_rate": round(hit_rate, 4),
|
||||
"mean_max_conf": round(mean_max_conf, 4),
|
||||
"missed_files": missed,
|
||||
}
|
||||
)
|
||||
|
||||
report = {
|
||||
"model": str(model_path),
|
||||
"images_dir": str(img_dir),
|
||||
"imgsz": args.imgsz,
|
||||
"box_format": args.box_format,
|
||||
"nms": args.nms,
|
||||
"results": results,
|
||||
"per_image": per_image_best,
|
||||
}
|
||||
report_path = img_dir / f"{model_path.stem}_roi_eval.json"
|
||||
report_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
print(f"\nreport saved: {report_path}")
|
||||
if save_dir is not None:
|
||||
print(f"visualizations saved: {save_dir}")
|
||||
return 0
|
||||
finally:
|
||||
release_backend(inst)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Loading…
Reference in New Issue
Block a user