95 lines
2.6 KiB
Python
95 lines
2.6 KiB
Python
import argparse
|
|
import time
|
|
|
|
import cv2
|
|
from ultralytics import YOLO
|
|
|
|
|
|
def _fraction(text):
    """argparse ``type=`` callback: parse *text* as a float in the closed range [0, 1].

    Raises:
        argparse.ArgumentTypeError: if *text* is not a number or lies outside
            [0, 1]; argparse reports it as a usage error (exit code 2).
    """
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a number in [0, 1], got {text!r}") from None
    # The chained comparison is also False for NaN, so "nan" is rejected too.
    if not 0.0 <= value <= 1.0:
        raise argparse.ArgumentTypeError(f"expected a number in [0, 1], got {text!r}")
    return value


def parse_args(argv=None):
    """Build the command-line parser for the webcam test and parse options.

    Args:
        argv: List of argument strings to parse, without the program name.
            ``None`` (the default) parses ``sys.argv[1:]`` exactly as before;
            pass an explicit list to drive the parser from code or tests.

    Returns:
        argparse.Namespace with attributes ``model``, ``camera``, ``imgsz``,
        ``conf``, ``iou``, ``classes`` (the raw comma-separated string),
        ``device`` and ``line_width``.

    Invalid or missing arguments make argparse print a usage message and
    raise ``SystemExit`` with code 2.
    """
    parser = argparse.ArgumentParser(description="Test YOLO .pt model with USB camera")
    parser.add_argument("--model", type=str, required=True, help="Path to .pt model")
    parser.add_argument("--camera", type=int, default=0, help="USB camera index")
    parser.add_argument("--imgsz", type=int, default=768, help="Inference image size")
    # Thresholds are fractions. Range-check them so a value such as
    # "--conf 25" (meant as 25%) fails loudly instead of silently
    # suppressing every detection.
    parser.add_argument("--conf", type=_fraction, default=0.25, help="Confidence threshold")
    parser.add_argument("--iou", type=_fraction, default=0.6, help="NMS IoU threshold")
    parser.add_argument(
        "--classes",
        type=str,
        default="",
        help="Optional class ids, e.g. '3,6' for boots+Person. Empty means all classes.",
    )
    parser.add_argument("--device", type=str, default="0", help="CUDA device id, e.g. 0, or 'cpu'")
    parser.add_argument("--line-width", type=int, default=2, help="Box line width")
    return parser.parse_args(argv)
|
|
|
|
|
|
def parse_classes(raw: str):
    """Convert the ``--classes`` string into a list of integer class ids.

    Items are separated by commas. Whitespace around each item is ignored,
    and so are empty items (e.g. from ``"3,,6"`` or a trailing comma).

    Args:
        raw: The raw option value, e.g. ``"3,6"``.

    Returns:
        The ids in the order given, or ``None`` when no ids are present
        (meaning: do not filter by class).

    Raises:
        ValueError: if a non-empty item is not an integer literal.
    """
    tokens = (piece.strip() for piece in raw.strip().split(","))
    ids = [int(token) for token in tokens if token]
    return ids or None
|
|
|
|
|
|
def main():
    """Run live YOLO inference on a USB camera and display annotated frames.

    Options come from ``parse_args``; ``--classes`` is converted with
    ``parse_classes``. Each loop iteration does the following:

    - grab a frame;
    - run ``model.predict`` with the configured image size, thresholds,
      class filter and device;
    - draw the detections plus an FPS/settings overlay;
    - show the result in an OpenCV window.

    The loop ends when a frame cannot be read or when ``q`` / ``Esc`` is
    pressed. The capture device and windows are released on every exit path.

    Raises:
        RuntimeError: if the camera at index ``--camera`` cannot be opened.
    """
    args = parse_args()
    classes = parse_classes(args.classes)

    model = YOLO(args.model)

    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Cannot open camera index {args.camera}")

    # The class filter is fixed for the whole run, so format it once.
    cls_info = "all" if classes is None else str(classes)

    try:
        # perf_counter() is monotonic. time.time() follows the wall clock and
        # can jump (NTP sync, manual changes), producing bogus FPS readings.
        prev_t = time.perf_counter()

        while True:
            ok, frame = cap.read()
            if not ok:
                print("Failed to read frame from camera")
                break

            results = model.predict(
                source=frame,
                imgsz=args.imgsz,
                conf=args.conf,
                iou=args.iou,
                classes=classes,
                device=args.device,
                verbose=False,
            )
            plotted = results[0].plot(line_width=args.line_width)

            now = time.perf_counter()
            # Clamp the interval so two near-identical timestamps cannot
            # cause a division by zero.
            fps = 1.0 / max(now - prev_t, 1e-6)
            prev_t = now

            cv2.putText(
                plotted,
                f"FPS: {fps:.1f} | conf={args.conf:.2f} iou={args.iou:.2f} classes={cls_info}",
                (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.8,
                (0, 255, 0),
                2,
                cv2.LINE_AA,
            )

            cv2.imshow("YOLO PT Webcam Test", plotted)
            # waitKey(1) lets HighGUI process window events.
            # "& 0xFF" keeps only the low byte of the key code; 27 is Esc.
            key = cv2.waitKey(1) & 0xFF
            if key in (27, ord("q")):
                break
    finally:
        # Free the capture device and close windows on every exit path,
        # including an exception from inference or a Ctrl+C.
        cap.release()
        cv2.destroyAllWindows()
|
|
|
|
|
|
if __name__ == "__main__":
    # Launch the webcam test only when run as a script, never on import.
    main()
|