Add ROI-based shoe training workflow
This commit is contained in:
parent
45f1f3182f
commit
22aee7fa1e
41
02_train.bat
41
02_train.bat
@ -2,8 +2,34 @@
|
||||
chcp 65001 >nul
|
||||
cls
|
||||
|
||||
:: 设置 Python 3.11 路径
|
||||
set "PATH=C:\Users\Tellme\AppData\Local\Programs\Python\Python311\Scripts;C:\Users\Tellme\AppData\Local\Programs\Python\Python311;%PATH%"
|
||||
set "REPO_DIR=%~dp0"
|
||||
pushd "%REPO_DIR%"
|
||||
|
||||
set "YOLO_LAUNCHER="
|
||||
where yolo >nul 2>nul
|
||||
if %ERRORLEVEL% equ 0 set "YOLO_LAUNCHER=yolo"
|
||||
|
||||
if not defined YOLO_LAUNCHER (
|
||||
py -3.11 -c "import ultralytics" >nul 2>nul
|
||||
if %ERRORLEVEL% equ 0 set "YOLO_LAUNCHER=py -3.11 -m ultralytics"
|
||||
)
|
||||
|
||||
if not defined YOLO_LAUNCHER (
|
||||
python -c "import ultralytics" >nul 2>nul
|
||||
if %ERRORLEVEL% equ 0 set "YOLO_LAUNCHER=python -m ultralytics"
|
||||
)
|
||||
|
||||
if not defined YOLO_LAUNCHER (
|
||||
echo [错误] 未找到可用的 Ultralytics 启动方式
|
||||
echo.
|
||||
echo 请先确保满足以下任一条件:
|
||||
echo 1. yolo 命令已加入 PATH
|
||||
echo 2. py -3.11 可运行并已安装 ultralytics
|
||||
echo 3. python 可运行并已安装 ultralytics
|
||||
popd
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo ============================================================
|
||||
echo 训练鞋子检测模型 (YOLOv8 + 640x640)
|
||||
@ -11,14 +37,15 @@ echo ============================================================
|
||||
echo.
|
||||
|
||||
:: 设置数据集路径
|
||||
set DATASET=datasets/openimages-shoes-yolo/data.yaml
|
||||
set "DATASET=%REPO_DIR%datasets\openimages-shoes-yolo\data.yaml"
|
||||
|
||||
:: 检查数据集是否存在
|
||||
if not exist %DATASET% (
|
||||
if not exist "%DATASET%" (
|
||||
echo [错误] 找不到数据集配置文件: %DATASET%
|
||||
echo.
|
||||
echo 请先下载数据集:
|
||||
echo python 01_download_dataset.py --source openimages --max-samples 5000
|
||||
echo py -3.11 "%REPO_DIR%01_download_dataset.py" --source openimages --max-samples 5000
|
||||
popd
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
@ -69,11 +96,12 @@ echo 开始训练
|
||||
echo ============================================================
|
||||
echo.
|
||||
|
||||
yolo detect train data=%DATASET% model=%MODEL% epochs=%EPOCHS% imgsz=%IMGSZ% batch=%BATCH% device=0
|
||||
call %YOLO_LAUNCHER% detect train data="%DATASET%" model="%MODEL%" epochs=%EPOCHS% imgsz=%IMGSZ% batch=%BATCH% device=0
|
||||
|
||||
if %ERRORLEVEL% neq 0 (
|
||||
echo.
|
||||
echo [错误] 训练失败!
|
||||
popd
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
@ -87,4 +115,5 @@ echo 模型保存在: runs/detect/train/weights/best.pt
|
||||
echo.
|
||||
echo 下一步: 运行 03_export_onnx.bat 导出 ONNX
|
||||
echo.
|
||||
popd
|
||||
pause
|
||||
|
||||
@ -2,19 +2,46 @@
|
||||
chcp 65001 >nul
|
||||
cls
|
||||
|
||||
:: 设置 Python 3.11 路径
|
||||
set "PATH=C:\Users\Tellme\AppData\Local\Programs\Python\Python311\Scripts;C:\Users\Tellme\AppData\Local\Programs\Python\Python311;%PATH%"
|
||||
set "REPO_DIR=%~dp0"
|
||||
pushd "%REPO_DIR%"
|
||||
|
||||
set "YOLO_LAUNCHER="
|
||||
where yolo >nul 2>nul
|
||||
if %ERRORLEVEL% equ 0 set "YOLO_LAUNCHER=yolo"
|
||||
|
||||
if not defined YOLO_LAUNCHER (
|
||||
py -3.11 -c "import ultralytics" >nul 2>nul
|
||||
if %ERRORLEVEL% equ 0 set "YOLO_LAUNCHER=py -3.11 -m ultralytics"
|
||||
)
|
||||
|
||||
if not defined YOLO_LAUNCHER (
|
||||
python -c "import ultralytics" >nul 2>nul
|
||||
if %ERRORLEVEL% equ 0 set "YOLO_LAUNCHER=python -m ultralytics"
|
||||
)
|
||||
|
||||
if not defined YOLO_LAUNCHER (
|
||||
echo [错误] 未找到可用的 Ultralytics 启动方式
|
||||
echo.
|
||||
echo 请先确保满足以下任一条件:
|
||||
echo 1. yolo 命令已加入 PATH
|
||||
echo 2. py -3.11 可运行并已安装 ultralytics
|
||||
echo 3. python 可运行并已安装 ultralytics
|
||||
popd
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo ============================================================
|
||||
echo 导出 ONNX 模型 (640x640)
|
||||
echo ============================================================
|
||||
echo.
|
||||
|
||||
set MODEL_PATH=%USERPROFILE%\apps\ultralytics\runs\detect\train\weights\best.pt
|
||||
set "MODEL_PATH=%REPO_DIR%runs\detect\train\weights\best.pt"
|
||||
|
||||
if not exist %MODEL_PATH% (
|
||||
if not exist "%MODEL_PATH%" (
|
||||
echo [错误] 找不到模型: %MODEL_PATH%
|
||||
echo 请先运行 02_train.bat 训练
|
||||
popd
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
@ -22,17 +49,19 @@ if not exist %MODEL_PATH% (
|
||||
echo [信息] 输入模型: %MODEL_PATH%
|
||||
echo.
|
||||
|
||||
yolo export model=%MODEL_PATH% format=onnx imgsz=640 opset=12 simplify
|
||||
call %YOLO_LAUNCHER% export model="%MODEL_PATH%" format=onnx imgsz=640 opset=12 simplify
|
||||
|
||||
if %ERRORLEVEL% neq 0 (
|
||||
echo [错误] 导出失败!
|
||||
popd
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo.
|
||||
echo [成功] ONNX 模型: %USERPROFILE%\apps\ultralytics\runs\detect\train\weights\best.onnx
|
||||
echo [成功] ONNX 模型: %REPO_DIR%runs\detect\train\weights\best.onnx
|
||||
echo.
|
||||
echo 下一步: 在 Ubuntu 上运行 04_convert_rknn.py 转换
|
||||
echo.
|
||||
popd
|
||||
pause
|
||||
|
||||
@ -1,26 +1,56 @@
|
||||
@echo off
|
||||
setlocal
|
||||
|
||||
set "PATH=C:\Users\Tellme\AppData\Local\Programs\Python\Python311\Scripts;C:\Users\Tellme\AppData\Local\Programs\Python\Python311;%PATH%"
|
||||
set "REPO_DIR=%~dp0"
|
||||
pushd "%REPO_DIR%"
|
||||
|
||||
set "YOLO_LAUNCHER="
|
||||
where yolo >nul 2>nul
|
||||
if %ERRORLEVEL% equ 0 set "YOLO_LAUNCHER=yolo"
|
||||
|
||||
if not defined YOLO_LAUNCHER (
|
||||
py -3.11 -c "import ultralytics" >nul 2>nul
|
||||
if %ERRORLEVEL% equ 0 set "YOLO_LAUNCHER=py -3.11 -m ultralytics"
|
||||
)
|
||||
|
||||
if not defined YOLO_LAUNCHER (
|
||||
python -c "import ultralytics" >nul 2>nul
|
||||
if %ERRORLEVEL% equ 0 set "YOLO_LAUNCHER=python -m ultralytics"
|
||||
)
|
||||
|
||||
if not defined YOLO_LAUNCHER (
|
||||
echo [ERROR] No usable Ultralytics launcher was found
|
||||
echo.
|
||||
echo Please make sure one of the following works:
|
||||
echo 1. yolo in PATH
|
||||
echo 2. py -3.11 with ultralytics installed
|
||||
echo 3. python with ultralytics installed
|
||||
popd
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo ============================================================
|
||||
echo Stage 2 Fine-tuning on PPE shoe subset
|
||||
echo ============================================================
|
||||
echo.
|
||||
|
||||
set "DATASET=C:\Users\Tellme\apps\ppe-model-training\datasets\ppe-shoes\data.yaml"
|
||||
set "BASE_MODEL=C:\Users\Tellme\apps\ultralytics\runs\detect\train3\weights\best.pt"
|
||||
set "DATASET=%REPO_DIR%datasets\ppe-shoes\data.yaml"
|
||||
set "BASE_MODEL=%REPO_DIR%runs\detect\train\weights\best.pt"
|
||||
|
||||
if not exist "%DATASET%" (
|
||||
echo [ERROR] PPE shoe subset not found: %DATASET%
|
||||
echo Run:
|
||||
echo C:\Users\Tellme\AppData\Local\Programs\Python\Python311\python.exe C:\Users\Tellme\apps\ppe-model-training\05_prepare_ppe_shoe_subset.py
|
||||
echo python "%REPO_DIR%05_prepare_ppe_shoe_subset.py"
|
||||
popd
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
if not exist "%BASE_MODEL%" (
|
||||
echo [ERROR] Base model not found: %BASE_MODEL%
|
||||
echo Run 02_train.bat first to create the base checkpoint.
|
||||
popd
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
@ -45,11 +75,12 @@ echo Start fine-tuning
|
||||
echo ============================================================
|
||||
echo.
|
||||
|
||||
yolo detect train data="%DATASET%" model="%BASE_MODEL%" epochs=%EPOCHS% imgsz=%IMGSZ% batch=%BATCH% device=0
|
||||
call %YOLO_LAUNCHER% detect train data="%DATASET%" model="%BASE_MODEL%" epochs=%EPOCHS% imgsz=%IMGSZ% batch=%BATCH% device=0
|
||||
|
||||
if %ERRORLEVEL% neq 0 (
|
||||
echo.
|
||||
echo [ERROR] Fine-tuning failed.
|
||||
popd
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
@ -59,4 +90,5 @@ echo ============================================================
|
||||
echo Fine-tuning complete
|
||||
echo ============================================================
|
||||
echo.
|
||||
popd
|
||||
pause
|
||||
|
||||
456
09_build_roi_shoe_dataset.py
Normal file
456
09_build_roi_shoe_dataset.py
Normal file
@ -0,0 +1,456 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Build a foot-ROI shoe dataset from existing YOLO shoe datasets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
DEFAULT_SOURCES = [
|
||||
"datasets/openimages-shoes-yolo",
|
||||
"datasets/ppe-shoes",
|
||||
]
|
||||
|
||||
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
|
||||
PAIR_MAX_X_GAP_FACTOR = 3.2
|
||||
PAIR_MAX_Y_GAP_FACTOR = 1.2
|
||||
PAIR_MIN_AREA_RATIO = 0.4
|
||||
PAIR_MAX_AREA_RATIO = 2.5
|
||||
|
||||
SINGLE_AREA_RANGE = (0.15, 0.35)
|
||||
PAIR_AREA_RANGE = (0.25, 0.50)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Box:
|
||||
x1: float
|
||||
y1: float
|
||||
x2: float
|
||||
y2: float
|
||||
|
||||
@property
|
||||
def w(self) -> float:
|
||||
return max(0.0, self.x2 - self.x1)
|
||||
|
||||
@property
|
||||
def h(self) -> float:
|
||||
return max(0.0, self.y2 - self.y1)
|
||||
|
||||
@property
|
||||
def area(self) -> float:
|
||||
return self.w * self.h
|
||||
|
||||
@property
|
||||
def cx(self) -> float:
|
||||
return (self.x1 + self.x2) / 2.0
|
||||
|
||||
@property
|
||||
def cy(self) -> float:
|
||||
return (self.y1 + self.y2) / 2.0
|
||||
|
||||
def clip(self, width: float, height: float) -> "Box | None":
|
||||
x1 = min(max(self.x1, 0.0), width)
|
||||
y1 = min(max(self.y1, 0.0), height)
|
||||
x2 = min(max(self.x2, 0.0), width)
|
||||
y2 = min(max(self.y2, 0.0), height)
|
||||
if x2 <= x1 or y2 <= y1:
|
||||
return None
|
||||
return Box(x1, y1, x2, y2)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RoiSample:
|
||||
roi: Box
|
||||
members: tuple[int, ...]
|
||||
mode: str
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Build a foot-context ROI shoe dataset")
|
||||
parser.add_argument(
|
||||
"--sources",
|
||||
nargs="+",
|
||||
default=DEFAULT_SOURCES,
|
||||
help="Source YOLO datasets containing images/<split> and labels/<split>",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default="datasets/shoe-roi-mix",
|
||||
help="Output ROI dataset directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clean",
|
||||
action="store_true",
|
||||
help="Delete the output directory before rebuilding",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def ensure_output_layout(output_dir: Path) -> None:
|
||||
for split in ("train", "val", "test"):
|
||||
(output_dir / "images" / split).mkdir(parents=True, exist_ok=True)
|
||||
(output_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def find_image(image_dir: Path, stem: str) -> Path | None:
|
||||
for ext in IMAGE_EXTS:
|
||||
candidate = image_dir / f"{stem}{ext}"
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
def load_boxes(label_path: Path, image_width: int, image_height: int) -> list[Box]:
|
||||
boxes: list[Box] = []
|
||||
for raw_line in label_path.read_text(encoding="utf-8").splitlines():
|
||||
line = raw_line.strip()
|
||||
if not line:
|
||||
continue
|
||||
parts = line.split()
|
||||
if len(parts) < 5:
|
||||
continue
|
||||
_, xc, yc, w, h = parts[:5]
|
||||
box_w = float(w) * image_width
|
||||
box_h = float(h) * image_height
|
||||
center_x = float(xc) * image_width
|
||||
center_y = float(yc) * image_height
|
||||
box = Box(
|
||||
center_x - box_w / 2.0,
|
||||
center_y - box_h / 2.0,
|
||||
center_x + box_w / 2.0,
|
||||
center_y + box_h / 2.0,
|
||||
).clip(image_width, image_height)
|
||||
if box is not None and box.area > 1.0:
|
||||
boxes.append(box)
|
||||
return dedupe_boxes(boxes)
|
||||
|
||||
|
||||
def dedupe_boxes(boxes: list[Box], iou_threshold: float = 0.9) -> list[Box]:
|
||||
kept: list[Box] = []
|
||||
for box in sorted(boxes, key=lambda item: item.area, reverse=True):
|
||||
if any(iou(box, existing) >= iou_threshold for existing in kept):
|
||||
continue
|
||||
kept.append(box)
|
||||
return sorted(kept, key=lambda item: (item.cx, item.cy))
|
||||
|
||||
|
||||
def iou(a: Box, b: Box) -> float:
|
||||
inter_x1 = max(a.x1, b.x1)
|
||||
inter_y1 = max(a.y1, b.y1)
|
||||
inter_x2 = min(a.x2, b.x2)
|
||||
inter_y2 = min(a.y2, b.y2)
|
||||
inter_w = max(0.0, inter_x2 - inter_x1)
|
||||
inter_h = max(0.0, inter_y2 - inter_y1)
|
||||
inter_area = inter_w * inter_h
|
||||
if inter_area <= 0:
|
||||
return 0.0
|
||||
union = a.area + b.area - inter_area
|
||||
return inter_area / union if union > 0 else 0.0
|
||||
|
||||
|
||||
def should_pair(left: Box, right: Box) -> bool:
|
||||
width_ref = max(left.w, right.w)
|
||||
height_ref = max(left.h, right.h)
|
||||
if width_ref <= 0 or height_ref <= 0:
|
||||
return False
|
||||
|
||||
dx = abs(left.cx - right.cx)
|
||||
dy = abs(left.cy - right.cy)
|
||||
area_ratio = left.area / right.area if right.area > 0 else math.inf
|
||||
|
||||
return (
|
||||
dx <= width_ref * PAIR_MAX_X_GAP_FACTOR
|
||||
and dy <= height_ref * PAIR_MAX_Y_GAP_FACTOR
|
||||
and PAIR_MIN_AREA_RATIO <= area_ratio <= PAIR_MAX_AREA_RATIO
|
||||
)
|
||||
|
||||
|
||||
def greedy_group_boxes(boxes: list[Box]) -> list[tuple[int, ...]]:
|
||||
if len(boxes) < 2:
|
||||
return [(idx,) for idx in range(len(boxes))]
|
||||
|
||||
candidates: list[tuple[float, int, int]] = []
|
||||
for i in range(len(boxes)):
|
||||
for j in range(i + 1, len(boxes)):
|
||||
if not should_pair(boxes[i], boxes[j]):
|
||||
continue
|
||||
dx = abs(boxes[i].cx - boxes[j].cx)
|
||||
dy = abs(boxes[i].cy - boxes[j].cy)
|
||||
score = dx + (0.5 * dy)
|
||||
candidates.append((score, i, j))
|
||||
|
||||
used: set[int] = set()
|
||||
groups: list[tuple[int, ...]] = []
|
||||
for _, i, j in sorted(candidates, key=lambda item: item[0]):
|
||||
if i in used or j in used:
|
||||
continue
|
||||
used.add(i)
|
||||
used.add(j)
|
||||
groups.append((i, j))
|
||||
|
||||
for idx in range(len(boxes)):
|
||||
if idx not in used:
|
||||
groups.append((idx,))
|
||||
|
||||
return groups
|
||||
|
||||
|
||||
def expand_single(box: Box) -> Box:
|
||||
return Box(
|
||||
box.x1 - (0.6 * box.w),
|
||||
box.y1 - (0.5 * box.h),
|
||||
box.x1 - (0.6 * box.w) + (2.2 * box.w),
|
||||
box.y1 - (0.5 * box.h) + (2.4 * box.h),
|
||||
)
|
||||
|
||||
|
||||
def expand_pair(boxes: list[Box], group: tuple[int, int]) -> Box:
|
||||
first = boxes[group[0]]
|
||||
second = boxes[group[1]]
|
||||
union_x1 = min(first.x1, second.x1)
|
||||
union_y1 = min(first.y1, second.y1)
|
||||
union_x2 = max(first.x2, second.x2)
|
||||
union_y2 = max(first.y2, second.y2)
|
||||
union_w = union_x2 - union_x1
|
||||
union_h = union_y2 - union_y1
|
||||
roi_x = union_x1 - (0.35 * union_w)
|
||||
roi_y = union_y1 - (0.45 * union_h)
|
||||
return Box(
|
||||
roi_x,
|
||||
roi_y,
|
||||
roi_x + (1.7 * union_w),
|
||||
roi_y + (2.0 * union_h),
|
||||
)
|
||||
|
||||
|
||||
def clamp_roi(roi: Box, image_width: int, image_height: int) -> Box | None:
|
||||
clipped = roi.clip(float(image_width), float(image_height))
|
||||
if clipped is None:
|
||||
return None
|
||||
|
||||
x1 = int(math.floor(clipped.x1))
|
||||
y1 = int(math.floor(clipped.y1))
|
||||
x2 = int(math.ceil(clipped.x2))
|
||||
y2 = int(math.ceil(clipped.y2))
|
||||
|
||||
x1 = max(0, min(x1, image_width - 1))
|
||||
y1 = max(0, min(y1, image_height - 1))
|
||||
x2 = max(x1 + 1, min(x2, image_width))
|
||||
y2 = max(y1 + 1, min(y2, image_height))
|
||||
return Box(float(x1), float(y1), float(x2), float(y2))
|
||||
|
||||
|
||||
def resize_roi_to_ratio(
|
||||
roi: Box,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
object_area: float,
|
||||
min_ratio: float,
|
||||
max_ratio: float,
|
||||
) -> Box | None:
|
||||
if object_area <= 0:
|
||||
return None
|
||||
|
||||
adjusted = roi
|
||||
target_ratio = (min_ratio + max_ratio) / 2.0
|
||||
for _ in range(3):
|
||||
roi_area = adjusted.area
|
||||
if roi_area <= 0:
|
||||
return None
|
||||
ratio = object_area / roi_area
|
||||
if min_ratio <= ratio <= max_ratio:
|
||||
break
|
||||
|
||||
scale = math.sqrt(ratio / target_ratio)
|
||||
if ratio < min_ratio:
|
||||
scale = max(0.6, min(0.95, scale))
|
||||
else:
|
||||
scale = min(1.8, max(1.05, scale))
|
||||
|
||||
new_w = adjusted.w * scale
|
||||
new_h = adjusted.h * scale
|
||||
cx = adjusted.cx
|
||||
cy = adjusted.cy
|
||||
adjusted = Box(cx - new_w / 2.0, cy - new_h / 2.0, cx + new_w / 2.0, cy + new_h / 2.0)
|
||||
adjusted = clamp_roi(adjusted, image_width, image_height)
|
||||
if adjusted is None:
|
||||
return None
|
||||
|
||||
return adjusted
|
||||
|
||||
|
||||
def boxes_in_roi(boxes: list[Box], roi: Box) -> list[Box]:
|
||||
included: list[Box] = []
|
||||
for box in boxes:
|
||||
if not (roi.x1 <= box.cx <= roi.x2 and roi.y1 <= box.cy <= roi.y2):
|
||||
continue
|
||||
clipped = Box(
|
||||
box.x1 - roi.x1,
|
||||
box.y1 - roi.y1,
|
||||
box.x2 - roi.x1,
|
||||
box.y2 - roi.y1,
|
||||
).clip(roi.w, roi.h)
|
||||
if clipped is not None and clipped.area > 4.0:
|
||||
included.append(clipped)
|
||||
return included
|
||||
|
||||
|
||||
def make_roi_samples(boxes: list[Box], image_width: int, image_height: int) -> list[RoiSample]:
|
||||
samples: list[RoiSample] = []
|
||||
groups = greedy_group_boxes(boxes)
|
||||
for group in groups:
|
||||
if len(group) == 2:
|
||||
roi = expand_pair(boxes, group)
|
||||
area_range = PAIR_AREA_RANGE
|
||||
mode = "pair"
|
||||
else:
|
||||
roi = expand_single(boxes[group[0]])
|
||||
area_range = SINGLE_AREA_RANGE
|
||||
mode = "single"
|
||||
|
||||
roi = clamp_roi(roi, image_width, image_height)
|
||||
if roi is None:
|
||||
continue
|
||||
|
||||
object_area = sum(boxes[idx].area for idx in group)
|
||||
roi = resize_roi_to_ratio(roi, image_width, image_height, object_area, *area_range)
|
||||
if roi is None:
|
||||
continue
|
||||
samples.append(RoiSample(roi=roi, members=group, mode=mode))
|
||||
return samples
|
||||
|
||||
|
||||
def to_yolo_lines(boxes: list[Box], roi_w: float, roi_h: float) -> list[str]:
|
||||
lines: list[str] = []
|
||||
for box in boxes:
|
||||
xc = ((box.x1 + box.x2) / 2.0) / roi_w
|
||||
yc = ((box.y1 + box.y2) / 2.0) / roi_h
|
||||
bw = box.w / roi_w
|
||||
bh = box.h / roi_h
|
||||
lines.append(f"0 {xc:.6f} {yc:.6f} {bw:.6f} {bh:.6f}")
|
||||
return lines
|
||||
|
||||
|
||||
def write_yaml(output_dir: Path, sources: list[str]) -> None:
|
||||
yaml_path = output_dir / "data.yaml"
|
||||
source_names = ", ".join(Path(item).name for item in sources)
|
||||
yaml_path.write_text(
|
||||
"\n".join(
|
||||
[
|
||||
"# ROI shoe training mix",
|
||||
"",
|
||||
f"path: {output_dir.resolve().as_posix()}",
|
||||
"train: images/train",
|
||||
"val: images/val",
|
||||
"test: images/test",
|
||||
"",
|
||||
"nc: 1",
|
||||
"names: ['shoe']",
|
||||
"",
|
||||
"dataset_info:",
|
||||
" name: shoe-roi-mix",
|
||||
" task: detect_shoe_roi",
|
||||
f" source: {source_names}",
|
||||
" note: cropped to foot-context ROIs to match online two-stage inference",
|
||||
"",
|
||||
]
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def build_split(source_dir: Path, output_dir: Path, split: str) -> dict[str, int]:
|
||||
image_dir = source_dir / "images" / split
|
||||
label_dir = source_dir / "labels" / split
|
||||
if not image_dir.exists() or not label_dir.exists():
|
||||
return {"images": 0, "boxes": 0, "single": 0, "pair": 0}
|
||||
|
||||
stats = {"images": 0, "boxes": 0, "single": 0, "pair": 0}
|
||||
prefix = source_dir.name.replace("-", "_")
|
||||
|
||||
for label_path in sorted(label_dir.glob("*.txt")):
|
||||
image_path = find_image(image_dir, label_path.stem)
|
||||
if image_path is None:
|
||||
continue
|
||||
|
||||
with Image.open(image_path) as image:
|
||||
image = image.convert("RGB")
|
||||
width, height = image.size
|
||||
boxes = load_boxes(label_path, width, height)
|
||||
if not boxes:
|
||||
continue
|
||||
|
||||
samples = make_roi_samples(boxes, width, height)
|
||||
for sample_idx, sample in enumerate(samples):
|
||||
roi_boxes = boxes_in_roi(boxes, sample.roi)
|
||||
if not roi_boxes:
|
||||
continue
|
||||
|
||||
out_stem = f"{prefix}_{label_path.stem}_{sample.mode}_{sample_idx:02d}"
|
||||
dst_image = output_dir / "images" / split / f"{out_stem}.jpg"
|
||||
dst_label = output_dir / "labels" / split / f"{out_stem}.txt"
|
||||
|
||||
crop = image.crop((sample.roi.x1, sample.roi.y1, sample.roi.x2, sample.roi.y2))
|
||||
crop.save(dst_image, quality=95)
|
||||
|
||||
yolo_lines = to_yolo_lines(roi_boxes, sample.roi.w, sample.roi.h)
|
||||
dst_label.write_text("\n".join(yolo_lines) + "\n", encoding="utf-8")
|
||||
|
||||
stats["images"] += 1
|
||||
stats["boxes"] += len(roi_boxes)
|
||||
stats[sample.mode] += 1
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
output_dir = Path(args.output)
|
||||
|
||||
if args.clean and output_dir.exists():
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
ensure_output_layout(output_dir)
|
||||
|
||||
summary: dict[str, dict[str, dict[str, int]]] = defaultdict(dict)
|
||||
totals = {"images": 0, "boxes": 0, "single": 0, "pair": 0}
|
||||
|
||||
for source in args.sources:
|
||||
source_dir = Path(source)
|
||||
if not source_dir.exists():
|
||||
raise FileNotFoundError(f"Source dataset not found: {source_dir}")
|
||||
|
||||
for split in ("train", "val", "test"):
|
||||
stats = build_split(source_dir, output_dir, split)
|
||||
summary[source_dir.name][split] = stats
|
||||
for key in totals:
|
||||
totals[key] += stats[key]
|
||||
|
||||
write_yaml(output_dir, args.sources)
|
||||
|
||||
print(f"Output dataset: {output_dir.resolve()}")
|
||||
for source_name, split_map in summary.items():
|
||||
print(f"[{source_name}]")
|
||||
for split in ("train", "val", "test"):
|
||||
stats = split_map.get(split, {"images": 0, "boxes": 0, "single": 0, "pair": 0})
|
||||
print(
|
||||
f" {split}: rois={stats['images']} boxes={stats['boxes']} "
|
||||
f"single={stats['single']} pair={stats['pair']}"
|
||||
)
|
||||
|
||||
print(
|
||||
"Total:"
|
||||
f" rois={totals['images']} boxes={totals['boxes']}"
|
||||
f" single={totals['single']} pair={totals['pair']}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,6 +1,6 @@
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
$repo = "C:\Users\tianj\Documents\apps\DetectionModelTraining"
|
||||
$repo = $PSScriptRoot
|
||||
$env:PYTHONPATH = "$repo\.pydeps"
|
||||
$env:YOLO_CONFIG_DIR = "$repo\.ultralytics"
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
$repo = "C:\Users\tianj\Documents\apps\DetectionModelTraining"
|
||||
$repo = $PSScriptRoot
|
||||
$env:PYTHONPATH = "$repo\.pydeps"
|
||||
$env:YOLO_CONFIG_DIR = "$repo\.ultralytics"
|
||||
|
||||
|
||||
92
12_train_roi_yolov8s_640.bat
Normal file
92
12_train_roi_yolov8s_640.bat
Normal file
@ -0,0 +1,92 @@
|
||||
@echo off
|
||||
chcp 65001 >nul
|
||||
cls
|
||||
|
||||
set "REPO_DIR=%~dp0"
|
||||
pushd "%REPO_DIR%"
|
||||
|
||||
set "YOLO_LAUNCHER="
|
||||
where yolo >nul 2>nul
|
||||
if %ERRORLEVEL% equ 0 set "YOLO_LAUNCHER=yolo"
|
||||
|
||||
if not defined YOLO_LAUNCHER (
|
||||
py -3.11 -c "import ultralytics" >nul 2>nul
|
||||
if %ERRORLEVEL% equ 0 set "YOLO_LAUNCHER=py -3.11 -m ultralytics"
|
||||
)
|
||||
|
||||
if not defined YOLO_LAUNCHER (
|
||||
python -c "import ultralytics" >nul 2>nul
|
||||
if %ERRORLEVEL% equ 0 set "YOLO_LAUNCHER=python -m ultralytics"
|
||||
)
|
||||
|
||||
if not defined YOLO_LAUNCHER (
|
||||
echo [错误] 未找到可用的 Ultralytics 启动方式
|
||||
echo.
|
||||
echo 请先确保满足以下任一条件:
|
||||
echo 1. yolo 命令已加入 PATH
|
||||
echo 2. py -3.11 可运行并已安装 ultralytics
|
||||
echo 3. python 可运行并已安装 ultralytics
|
||||
popd
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
set "DATASET=%REPO_DIR%datasets\shoe-roi-mix\data.yaml"
|
||||
|
||||
if not exist "%DATASET%" (
|
||||
echo [错误] 找不到 ROI 数据集配置: %DATASET%
|
||||
echo.
|
||||
echo 请先构建 ROI 数据集:
|
||||
echo python "%REPO_DIR%09_build_roi_shoe_dataset.py" --clean
|
||||
popd
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo ============================================================
|
||||
echo 训练鞋子 ROI 检测模型 (YOLOv8s + 640x640)
|
||||
echo ============================================================
|
||||
echo.
|
||||
echo [信息] 数据集: %DATASET%
|
||||
echo [信息] 模型: yolov8s.pt
|
||||
echo.
|
||||
|
||||
set "EPOCHS=150"
|
||||
set "IMGSZ=640"
|
||||
set "BATCH=16"
|
||||
set "PROJECT=%REPO_DIR%runs\roi_yolov8s_640"
|
||||
set "RUN_NAME=train_roi"
|
||||
|
||||
echo 训练参数:
|
||||
echo - Epochs: %EPOCHS%
|
||||
echo - Image Size: %IMGSZ%x%IMGSZ%
|
||||
echo - Batch Size: %BATCH%
|
||||
echo - Device: GPU (cuda:0)
|
||||
echo - Project: %PROJECT%
|
||||
echo - Run Name: %RUN_NAME% ^(已存在时会自动递增,不覆盖旧模型^)
|
||||
echo.
|
||||
|
||||
echo ============================================================
|
||||
echo 开始训练
|
||||
echo ============================================================
|
||||
echo.
|
||||
|
||||
call %YOLO_LAUNCHER% detect train data="%DATASET%" model="yolov8s.pt" epochs=%EPOCHS% imgsz=%IMGSZ% batch=%BATCH% device=0 project="%PROJECT%" name="%RUN_NAME%"
|
||||
|
||||
if %ERRORLEVEL% neq 0 (
|
||||
echo.
|
||||
echo [错误] 训练失败!
|
||||
popd
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo.
|
||||
echo ============================================================
|
||||
echo 训练完成!
|
||||
echo ============================================================
|
||||
echo.
|
||||
echo 模型输出目录: %PROJECT%
|
||||
echo.
|
||||
popd
|
||||
pause
|
||||
63
README.md
63
README.md
@ -1,5 +1,68 @@
|
||||
# 鞋子检测模型训练指南
|
||||
|
||||
## 当前主方案:YOLOv8s-640 + 脚部 ROI 训练
|
||||
|
||||
当前项目的主训练方向已经调整为:
|
||||
- 只训练 `yolov8s`、输入尺寸固定 `640x640`
|
||||
- 训练数据不再直接使用“整张场景图”或“鞋子纯特写图”
|
||||
- 先根据鞋框裁出更接近线上输入分布的“脚部 ROI 图”,再训练鞋检测模型
|
||||
|
||||
这样做的原因是线上链路并不是直接在整张图上找鞋,而是:
|
||||
1. 先从人体框生成脚部 ROI
|
||||
2. 再在脚部 ROI 上做鞋检测
|
||||
|
||||
因此训练阶段也尽量模拟这个输入分布,保留一些裤脚、地面和周围背景,避免训练样本过于像商品特写。
|
||||
|
||||
### ROI 规则
|
||||
|
||||
单鞋 ROI:
|
||||
- 已知鞋框 `(x, y, w, h)`
|
||||
- `roi_x = x - 0.6w`
|
||||
- `roi_y = y - 0.5h`
|
||||
- `roi_w = 2.2w`
|
||||
- `roi_h = 2.4h`
|
||||
|
||||
双鞋 ROI:
|
||||
- 优先把两只鞋裁进同一张 ROI
|
||||
- 先取两只鞋框并集,再扩框:
|
||||
- `roi_x = union_x - 0.35 * union_w`
|
||||
- `roi_y = union_y - 0.45 * union_h`
|
||||
- `roi_w = 1.7 * union_w`
|
||||
- `roi_h = 2.0 * union_h`
|
||||
|
||||
裁图会自动裁剪到图像边界内。
|
||||
|
||||
### 新主流程
|
||||
|
||||
1. 准备原始单类鞋数据集
|
||||
|
||||
```bash
|
||||
python 01_download_dataset.py --source openimages --max-samples 5000
|
||||
python 05_prepare_ppe_shoe_subset.py
|
||||
```
|
||||
|
||||
2. 构建 ROI 化训练集
|
||||
|
||||
```bash
|
||||
python 09_build_roi_shoe_dataset.py --clean
|
||||
```
|
||||
|
||||
输出目录:
|
||||
- `datasets/shoe-roi-mix`
|
||||
|
||||
3. 训练新的 ROI 模型
|
||||
|
||||
```bash
|
||||
12_train_roi_yolov8s_640.bat
|
||||
```
|
||||
|
||||
模型输出目录:
|
||||
- `runs/roi_yolov8s_640`
|
||||
|
||||
说明:
|
||||
- 新模型会写到新的项目目录,不覆盖之前已有模型
|
||||
- 如果 `train_roi` 已存在,Ultralytics 会自动递增运行目录名
|
||||
|
||||
## 方案:640x640 单模型(部署时用2窗口)
|
||||
|
||||
**训练阶段**:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user