项目初始化;已经训练了三种yolo版本的2种尺寸的模型
This commit is contained in:
commit
45f1f3182f
48
.gitignore
vendored
Normal file
48
.gitignore
vendored
Normal file
@ -0,0 +1,48 @@
|
||||
# Python cache and local tooling
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.pyo
|
||||
*.pyd
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
.ruff_cache/
|
||||
.coverage
|
||||
.coverage.*
|
||||
|
||||
# Virtualenvs and local dependency folders
|
||||
.venv/
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
.pydeps/
|
||||
|
||||
# Project-local tool state
|
||||
.ultralytics/
|
||||
|
||||
# OS / editor noise
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
Desktop.ini
|
||||
*.tmp
|
||||
*.temp
|
||||
|
||||
# Large datasets and generated training assets
|
||||
datasets/
|
||||
runs/
|
||||
|
||||
# Exported / downloaded model artifacts
|
||||
pretrained/
|
||||
*.pt
|
||||
*.pth
|
||||
*.onnx
|
||||
*.rknn
|
||||
*.engine
|
||||
*.bin
|
||||
|
||||
# Logs and reports
|
||||
*.log
|
||||
|
||||
# Keep placeholder files in otherwise ignored trees
|
||||
!datasets/.gitkeep
|
||||
!runs/.gitkeep
|
||||
!pretrained/.gitkeep
|
||||
411
01_download_dataset.py
Normal file
411
01_download_dataset.py
Normal file
@ -0,0 +1,411 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
下载鞋子检测数据集。
|
||||
|
||||
支持:
|
||||
- Ultralytics Construction-PPE
|
||||
- Open Images V7 (推荐用于单类 shoe 检测)
|
||||
|
||||
Open Images 推荐类别:
|
||||
- Footwear
|
||||
- Boot
|
||||
|
||||
可选补充:
|
||||
- Sandal
|
||||
|
||||
不建议默认加入:
|
||||
- High heels
|
||||
- Roller skates
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import ssl
|
||||
import sys
|
||||
import urllib.request
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
OPENIMAGES_RECOMMENDED_CLASSES = ["Footwear", "Boot"]
|
||||
OPENIMAGES_OPTIONAL_CLASSES = ["Sandal"]
|
||||
OPENIMAGES_NOT_RECOMMENDED_CLASSES = ["High heels", "Roller skates"]
|
||||
|
||||
|
||||
def download_ultralytics_cppe(dataset_dir: str = "datasets/construction-ppe"):
    """Download and unpack the Ultralytics Construction-PPE dataset.

    Args:
        dataset_dir: Directory the archive is extracted into (created if missing).

    Returns:
        True on success, False when download or extraction failed.
    """
    url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/construction-ppe.zip"
    zip_path = "construction-ppe.zip"

    print("=" * 70)
    print("下载 Construction-PPE 数据集")
    print("=" * 70)
    print(f"来源: {url}")
    print(f"目标: {dataset_dir}")
    print()

    os.makedirs(dataset_dir, exist_ok=True)

    print("[1/3] 下载中... (约 178MB)")
    try:
        # SECURITY NOTE: certificate verification is deliberately disabled
        # (tolerates broken proxies) — this permits MITM of the download.
        # Re-enable verification whenever the environment allows it.
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE

        with urllib.request.urlopen(url, context=ssl_context, timeout=300) as response:
            total_size = int(response.headers.get("content-length", 0))
            downloaded = 0
            chunk_size = 8192

            with open(zip_path, "wb") as f:
                while True:
                    chunk = response.read(chunk_size)
                    if not chunk:
                        break
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total_size > 0:
                        percent = (downloaded / total_size) * 100
                        # flush=True: a carriage-return progress line with
                        # end="" may otherwise never appear on buffered stdout.
                        print(
                            f"\r 进度: {percent:.1f}% ({downloaded}/{total_size} bytes)",
                            end="",
                            flush=True,
                        )

        print("\n ✓ 下载完成")
    except Exception as e:
        # Best-effort download: on any failure print manual instructions.
        print(f"\n ✗ 下载失败: {e}")
        print("\n请手动下载:")
        print(f" 1. 访问: {url}")
        print(" 2. 下载 construction-ppe.zip")
        print(f" 3. 解压到 {dataset_dir}/")
        return False

    print("\n[2/3] 解压中...")
    try:
        # NOTE(review): extractall trusts archive member paths; acceptable for
        # this known Ultralytics release, unsafe for untrusted zips.
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(dataset_dir)
        print(" ✓ 解压完成")
    except Exception as e:
        print(f" ✗ 解压失败: {e}")
        return False

    print("\n[3/3] 清理临时文件...")
    os.remove(zip_path)
    print(" ✓ 完成")
    return True
|
||||
|
||||
|
||||
def create_single_class_yaml(dataset_dir: str, source_name: str, dataset_path_value: str = "."):
    """Write a single-class "shoe" detection data.yaml into *dataset_dir*.

    Returns the path of the file that was written.
    """
    config_lines = [
        "# 单类鞋子检测数据集配置",
        "",
        f"path: {dataset_path_value}",
        "train: images/train",
        "val: images/val",
        "test: images/test",
        "",
        "nc: 1",
        "names: ['shoe']",
        "",
        "dataset_info:",
        f"  name: {source_name}",
        "  task: detect_shoe",
        "  note: 所有鞋类子类统一映射为单一类别 shoe",
        "",
    ]

    yaml_path = os.path.join(dataset_dir, "data.yaml")
    with open(yaml_path, "w", encoding="utf-8") as f:
        f.write("\n".join(config_lines))

    print(f"\n✓ 配置文件创建: {yaml_path}")
    return yaml_path
|
||||
|
||||
|
||||
def rewrite_yaml_for_existing_splits(dataset_dir: str, source_name: str):
    """Regenerate data.yaml so its splits match the directories actually present.

    Raises:
        RuntimeError: when no images/<split> directory exists at all.

    Returns:
        The path of the rewritten yaml, as a string.
    """
    images_root = Path(dataset_dir) / "images"
    present = [s for s in ("train", "val", "test") if (images_root / s).exists()]

    if not present:
        raise RuntimeError(f"未找到任何图像 split: {images_root}")

    # Fall back gracefully: reuse whatever split exists for the missing ones.
    train_split = "train" if "train" in present else present[0]
    val_split = "val" if "val" in present else train_split
    test_split = "test" if "test" in present else val_split

    body = [
        "# 单类鞋子检测数据集配置",
        "",
        "path: .",
        f"train: images/{train_split}",
        f"val: images/{val_split}",
        f"test: images/{test_split}",
        "",
        "nc: 1",
        "names: ['shoe']",
        "",
        "dataset_info:",
        f"  name: {source_name}",
        "  task: detect_shoe",
        "  note: 所有鞋类子类统一映射为单一类别 shoe",
        "",
    ]

    yaml_path = Path(dataset_dir) / "data.yaml"
    yaml_path.write_text("\n".join(body), encoding="utf-8")
    print(f"\n✓ 配置文件更新: {yaml_path}")
    return str(yaml_path)
|
||||
|
||||
|
||||
def ensure_openimages_train_val_split(export_dir: str, train_ratio: float = 0.9, seed: int = 42):
    """Split a single-"val" FiftyOne export into train/val in place.

    No-op when a train split already exists, when the val directories are
    missing, or when there are fewer than two images to split.
    """
    root = Path(export_dir)
    val_images = root / "images" / "val"
    val_labels = root / "labels" / "val"
    train_images = root / "images" / "train"
    train_labels = root / "labels" / "train"

    # Already split — leave everything untouched.
    if train_images.exists() and train_labels.exists():
        return
    # Nothing usable to split from.
    if not (val_images.exists() and val_labels.exists()):
        return

    candidates = sorted(p for p in val_images.iterdir() if p.is_file())
    if len(candidates) < 2:
        return

    train_images.mkdir(parents=True, exist_ok=True)
    train_labels.mkdir(parents=True, exist_ok=True)

    # Deterministic shuffle so repeated runs reproduce the same split.
    random.Random(seed).shuffle(candidates)

    # Clamp so each side of the split keeps at least one file.
    cut = max(1, min(len(candidates) - 1, int(len(candidates) * train_ratio)))

    for image_path in candidates[:cut]:
        label_path = val_labels / f"{image_path.stem}.txt"
        shutil.move(str(image_path), train_images / image_path.name)
        if label_path.exists():
            shutil.move(str(label_path), train_labels / label_path.name)

    print("\n自动切分单一 split 为 train/val:")
    print(f" train: {len(list(train_images.iterdir()))} 张")
    print(f" val: {len(list(val_images.iterdir()))} 张")
|
||||
|
||||
|
||||
def merge_openimages_to_single_class(export_dir: str):
    """Collapse a multi-class Open Images YOLO export into single class "shoe".

    Rewrites every label file under <export_dir>/labels so each box uses class
    id 0, ensures a train/val split exists, and regenerates data.yaml.

    Returns:
        True on success, False when the labels directory is missing.
    """
    labels_root = Path(export_dir) / "labels"
    if not labels_root.exists():
        print(f"✗ 未找到标签目录: {labels_root}")
        return False

    total_files = 0
    total_boxes = 0
    changed_files = 0

    for label_file in labels_root.rglob("*.txt"):
        total_files += 1
        lines = label_file.read_text(encoding="utf-8").splitlines()
        rewritten = []
        file_changed = False

        for line in lines:
            parts = line.strip().split()
            # Silently drop malformed rows (fewer than 5 fields).
            if len(parts) < 5:
                continue
            # Force every box onto class id 0 ("shoe").
            parts[0] = "0"
            rewritten.append(" ".join(parts))
            total_boxes += 1
            # NOTE(review): set for every kept box, so "changed" here means
            # "contains at least one valid box", not "content actually differed".
            file_changed = True

        # Files are rewritten unconditionally, even when already single-class.
        label_file.write_text("\n".join(rewritten) + ("\n" if rewritten else ""), encoding="utf-8")

        if file_changed:
            changed_files += 1

    # Guarantee both splits exist, then refresh data.yaml to match them.
    ensure_openimages_train_val_split(export_dir)
    rewrite_yaml_for_existing_splits(export_dir, "Open Images V7")

    print("\n单类合并完成:")
    print(f" 标签文件: {changed_files}/{total_files}")
    print(f" 标注框数: {total_boxes}")
    return True
|
||||
|
||||
|
||||
def download_openimages(classes: list, max_samples: int, dataset_dir: str):
    """Download Open Images V7 detections via FiftyOne and export a
    single-class "shoe" YOLO dataset into f"{dataset_dir}-yolo".

    Args:
        classes: Open Images class names to fetch (e.g. ["Footwear", "Boot"]).
        max_samples: Upper bound on the number of images to download.
        dataset_dir: Base directory; the YOLO export lands in "<dataset_dir>-yolo".

    Returns:
        True on success, False on missing dependency or download/export error.
    """
    try:
        # Imported lazily so the script still works without FiftyOne when
        # another data source is selected.
        import fiftyone as fo
        import fiftyone.zoo as foz
    except ImportError:
        print("错误: 未安装 fiftyone")
        print("请运行: pip install fiftyone fiftyone-db-ubuntu2204")
        return False

    export_dir = dataset_dir + "-yolo"

    print("=" * 70)
    print("下载 Open Images V7 数据集")
    print("=" * 70)
    print(f"类别: {classes}")
    print(f"最大样本数: {max_samples}")
    print("原始缓存目录: FiftyOne 默认缓存目录")
    print(f"YOLO 导出目录: {export_dir}")
    print()

    # Start from a clean export directory so stale files never mix in.
    if os.path.exists(export_dir):
        print(f"检测到已有导出目录,先删除: {export_dir}")
        shutil.rmtree(export_dir)

    try:
        dataset = foz.load_zoo_dataset(
            "open-images-v7",
            split="train",
            label_types=["detections"],
            classes=classes,
            max_samples=max_samples,
        )

        print("\n导出为 YOLO 格式...")
        dataset.export(
            export_dir=export_dir,
            dataset_type=fo.types.YOLOv5Dataset,
            label_field="ground_truth",
        )

        # Remap every exported class onto the single "shoe" class.
        if not merge_openimages_to_single_class(export_dir):
            return False

        print(f"\n✓ 数据集保存: {export_dir}")
        return True
    except Exception as e:
        # Best-effort: any download/export failure is reported, not raised.
        print(f"✗ 下载失败: {e}")
        return False
|
||||
|
||||
|
||||
def check_dataset(dataset_dir: str):
    """Verify the expected YOLO dataset layout under *dataset_dir*.

    Checks the six images/labels split directories plus data.yaml, printing a
    per-entry status line with file counts.

    Returns:
        True when every directory and data.yaml exist, else False.
    """
    print("\n" + "=" * 70)
    print("检查数据集")
    print("=" * 70)

    required_dirs = [
        "images/train",
        "images/val",
        "images/test",
        "labels/train",
        "labels/val",
        "labels/test",
    ]

    all_ok = True
    for dir_name in required_dirs:
        full_path = os.path.join(dataset_dir, dir_name)
        if os.path.exists(full_path):
            # Count plain files only; subdirectories are ignored.
            count = len(
                [f for f in os.listdir(full_path) if os.path.isfile(os.path.join(full_path, f))]
            )
            print(f" ✓ {dir_name}: {count} 个文件")
        else:
            print(f" ✗ {dir_name}: 不存在")
            all_ok = False

    yaml_path = os.path.join(dataset_dir, "data.yaml")
    if os.path.exists(yaml_path):
        # Plain strings: these two messages have no interpolated fields
        # (the originals carried a pointless f-prefix).
        print(" ✓ data.yaml: 已生成")
    else:
        print(" ✗ data.yaml: 不存在")
        all_ok = False

    return all_ok
|
||||
|
||||
|
||||
def main():
    """CLI entry point: download a shoe dataset, post-process and verify it.

    Returns:
        Process exit code: 0 on success, 1 on failure.
    """
    parser = argparse.ArgumentParser(
        description="下载鞋子检测数据集",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
# 下载 Construction-PPE
python 01_download_dataset.py --source ultralytics

# 下载 Open Images 推荐鞋子类别
python 01_download_dataset.py --source openimages --max-samples 5000

# 如需补充凉鞋样本
python 01_download_dataset.py --source openimages --classes Footwear Boot Sandal --max-samples 8000
""",
    )

    parser.add_argument(
        "--source",
        choices=["ultralytics", "openimages"],
        default="openimages",
        help="数据源 (默认: openimages)",
    )
    parser.add_argument(
        "--dir",
        default="datasets/openimages-shoes",
        help="数据集保存目录",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=5000,
        help="Open Images 最大样本数 (默认: 5000)",
    )
    parser.add_argument(
        "--classes",
        nargs="+",
        default=OPENIMAGES_RECOMMENDED_CLASSES,
        help="Open Images 类别 (默认: Footwear Boot)",
    )

    args = parser.parse_args()
    success = False
    # For Open Images the final export lands in "<dir>-yolo"; updated below.
    final_dataset_dir = args.dir

    print("=" * 70)
    print("鞋子检测数据集准备")
    print("=" * 70)
    if args.source == "openimages":
        print(f"推荐类别: {OPENIMAGES_RECOMMENDED_CLASSES}")
        print(f"可选补充: {OPENIMAGES_OPTIONAL_CLASSES}")
        print(f"默认不建议: {OPENIMAGES_NOT_RECOMMENDED_CLASSES}")
    print()

    if args.source == "ultralytics":
        success = download_ultralytics_cppe(args.dir)
        if success:
            # The zip extracts into a "construction-ppe" subdirectory,
            # hence the relative dataset_path_value.
            create_single_class_yaml(args.dir, "Construction-PPE", dataset_path_value="construction-ppe")
            check_dataset(args.dir)

    elif args.source == "openimages":
        success = download_openimages(args.classes, args.max_samples, args.dir)
        final_dataset_dir = args.dir + "-yolo"
        if success:
            check_dataset(final_dataset_dir)

    if success:
        print("\n" + "=" * 70)
        print("数据集准备完成!")
        print("=" * 70)
        print(f"训练数据集路径: {final_dataset_dir}")
        print("\n下一步:")
        print(f" 1. 检查配置: {final_dataset_dir}/data.yaml")
        print(f" 2. 开始训练: 02_train.bat")
        print(
            f" 3. 或手动: yolo detect train data={final_dataset_dir}/data.yaml model=yolov8s.pt epochs=150 imgsz=640"
        )
        return 0

    print("\n✗ 数据集准备失败")
    return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
90
02_train.bat
Normal file
90
02_train.bat
Normal file
@ -0,0 +1,90 @@
|
||||
@echo off
:: Train the single-class shoe detector with Ultralytics YOLOv8.
:: Interactive: prompts for a model size, then runs `yolo detect train`.
chcp 65001 >nul
cls

:: Prepend the Python 3.11 install to PATH so `yolo` resolves to that env.
set "PATH=C:\Users\Tellme\AppData\Local\Programs\Python\Python311\Scripts;C:\Users\Tellme\AppData\Local\Programs\Python\Python311;%PATH%"

echo ============================================================
echo 训练鞋子检测模型 (YOLOv8 + 640x640)
echo ============================================================
echo.

:: Dataset produced by: python 01_download_dataset.py --source openimages
set DATASET=datasets/openimages-shoes-yolo/data.yaml

:: Abort early with instructions when the dataset is missing.
if not exist %DATASET% (
    echo [错误] 找不到数据集配置文件: %DATASET%
    echo.
    echo 请先下载数据集:
    echo python 01_download_dataset.py --source openimages --max-samples 5000
    pause
    exit /b 1
)

echo [信息] 数据集: %DATASET%
echo.

:: Let the user pick a model size (default: YOLOv8s).
echo 选择模型:
echo 1. YOLOv8n (轻量级, 速度快)
echo 2. YOLOv8s (推荐, 速度和精度平衡)
echo 3. YOLOv8m (高精度, 较慢)
echo.
set /p MODEL_CHOICE="输入选择 (1-3, 默认 2): "

if "%MODEL_CHOICE%"=="" set MODEL_CHOICE=2
if "%MODEL_CHOICE%"=="1" (
    set MODEL=yolov8n.pt
    set DESC=YOLOv8n
)
if "%MODEL_CHOICE%"=="2" (
    set MODEL=yolov8s.pt
    set DESC=YOLOv8s (推荐)
)
if "%MODEL_CHOICE%"=="3" (
    set MODEL=yolov8m.pt
    set DESC=YOLOv8m
)

echo.
echo [信息] 使用模型: %DESC%
echo.

:: Fixed training hyperparameters.
set EPOCHS=150
set IMGSZ=640
set BATCH=16

echo 训练参数:
echo - Epochs: %EPOCHS%
echo - Image Size: %IMGSZ%x%IMGSZ%
echo - Batch Size: %BATCH%
echo - Device: GPU (cuda:0)
echo.

echo ============================================================
echo 开始训练
echo ============================================================
echo.

yolo detect train data=%DATASET% model=%MODEL% epochs=%EPOCHS% imgsz=%IMGSZ% batch=%BATCH% device=0

if %ERRORLEVEL% neq 0 (
    echo.
    echo [错误] 训练失败!
    pause
    exit /b 1
)

echo.
echo ============================================================
echo 训练完成!
echo ============================================================
echo.
echo 模型保存在: runs/detect/train/weights/best.pt
echo.
echo 下一步: 运行 03_export_onnx.bat 导出 ONNX
echo.
pause
|
||||
38
03_export_onnx.bat
Normal file
38
03_export_onnx.bat
Normal file
@ -0,0 +1,38 @@
|
||||
@echo off
:: Export the trained YOLOv8 checkpoint to ONNX (640x640, opset 12).
chcp 65001 >nul
cls

:: Prepend the Python 3.11 install to PATH so `yolo` resolves to that env.
set "PATH=C:\Users\Tellme\AppData\Local\Programs\Python\Python311\Scripts;C:\Users\Tellme\AppData\Local\Programs\Python\Python311;%PATH%"

echo ============================================================
echo 导出 ONNX 模型 (640x640)
echo ============================================================
echo.

:: Best checkpoint produced by 02_train.bat.
set MODEL_PATH=%USERPROFILE%\apps\ultralytics\runs\detect\train\weights\best.pt

if not exist %MODEL_PATH% (
    echo [错误] 找不到模型: %MODEL_PATH%
    echo 请先运行 02_train.bat 训练
    pause
    exit /b 1
)

echo [信息] 输入模型: %MODEL_PATH%
echo.

yolo export model=%MODEL_PATH% format=onnx imgsz=640 opset=12 simplify

if %ERRORLEVEL% neq 0 (
    echo [错误] 导出失败!
    pause
    exit /b 1
)

echo.
echo [成功] ONNX 模型: %USERPROFILE%\apps\ultralytics\runs\detect\train\weights\best.onnx
echo.
echo 下一步: 在 Ubuntu 上运行 04_convert_rknn.py 转换
echo.
pause
|
||||
262
04_convert_rknn.py
Normal file
262
04_convert_rknn.py
Normal file
@ -0,0 +1,262 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
将 YOLOv8 ONNX 模型转换为 RKNN 格式
|
||||
适用于 RK3588 / RK3568 / RK3576 等平台
|
||||
|
||||
环境要求:
|
||||
- Ubuntu x86_64 / Docker
|
||||
- Python 3.8 / 3.9 / 3.10 / 3.11
|
||||
- rknn-toolkit2 (pip install rknn-toolkit2==2.2.0)
|
||||
|
||||
使用方法:
|
||||
# FP16 模式(推荐,速度快精度高)
|
||||
python 04_convert_rknn.py best.onnx -o shoe_detector.rknn -t rk3588
|
||||
|
||||
# INT8 量化(模型更小,需要校准数据集)
|
||||
python 04_convert_rknn.py best.onnx -o shoe_detector.rknn -t rk3588 -q -d dataset.txt
|
||||
|
||||
支持的 target_platform:
|
||||
- rk3588 / rk3588s
|
||||
- rk3568 / rk3566
|
||||
- rk3576
|
||||
- rv1106 / rv1103 / rv1103b
|
||||
- rv1126
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def check_environment():
    """Report whether RKNN Toolkit2 can be imported.

    Returns:
        True when available; otherwise prints install hints and returns False.
    """
    try:
        from rknn.api import RKNN  # noqa: F401 - the import itself is the probe
    except ImportError:
        print("✗ 错误: 未安装 RKNN Toolkit2")
        print("\n请安装:")
        print(" pip install rknn-toolkit2==2.2.0")
        print("\n或从源码安装:")
        print(" https://github.com/airockchip/rknn-toolkit2")
        return False
    print("✓ RKNN Toolkit2 已安装")
    return True
|
||||
|
||||
|
||||
def create_sample_dataset(onnx_path: str, output_path: str = "dataset.txt", num_samples: int = 20):
    """Write a template calibration list for INT8 quantization.

    Args:
        onnx_path: Unused; kept for interface compatibility.
        output_path: Where the template txt file is written.
        num_samples: Unused; kept for interface compatibility (the template
            merely recommends 20-100 images in its comments).

    Returns:
        The path of the file that was written.
    """
    print(f"\n创建示例校准数据集: {output_path}")
    print("注意: 请用实际图片替换这些示例路径")

    # Plain string: the template contains no interpolated fields
    # (the original carried a pointless f-prefix).
    sample_content = """# RKNN INT8 量化校准数据集
# 每行一个图片路径,建议使用 20-100 张典型场景图片
# 图片格式: JPG, PNG, BMP 等

# 示例路径(请替换为实际路径):
# /path/to/train/images/img001.jpg
# /path/to/train/images/img002.jpg
# /path/to/valid/images/img001.jpg

# 提示:
# 1. 图片应与实际部署场景相似
# 2. 包含各种光照、角度、背景的样本
# 3. 建议 20-100 张,越多越慢但可能更准
"""

    # Explicit utf-8: the template is Chinese text and must not depend on the
    # platform's default (e.g. cp936/ASCII) encoding.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(sample_content)

    print(f"✓ 示例数据集已创建: {output_path}")
    print(" 请编辑此文件,添加实际的图片路径")
    return output_path
|
||||
|
||||
|
||||
def convert_onnx_to_rknn(
    onnx_path: str,
    output_path: str = None,
    target_platform: str = "rk3588",
    do_quantization: bool = False,
    dataset_path: str = None,
    verbose: bool = True
):
    """Convert a YOLOv8 ONNX model to RKNN format.

    Args:
        onnx_path: Path to the input ONNX model.
        output_path: Output .rknn path; defaults to the ONNX path with the
            extension swapped.
        target_platform: Rockchip target (rk3588, rk3568, ...).
        do_quantization: Enable INT8 quantization (requires dataset_path).
        dataset_path: Calibration list (txt file, one image path per line).
        verbose: Forwarded to the RKNN constructor.

    Returns:
        True on success, False on any failure.
    """
    if output_path is None:
        # NOTE(review): str.replace is a no-op when the path does not literally
        # contain ".onnx", so the default output would equal the input —
        # confirm callers always pass *.onnx files.
        output_path = onnx_path.replace(".onnx", ".rknn")

    # Ensure the output directory exists before the toolkit writes to it.
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    print("="*70)
    print(f"ONNX 转 RKNN")
    print("="*70)
    print(f"输入: {onnx_path}")
    print(f"输出: {output_path}")
    print(f"目标: {target_platform}")
    print(f"量化: {'INT8' if do_quantization else 'FP16 (无量化)'}")
    if do_quantization:
        print(f"校准: {dataset_path}")
    print("="*70)

    # Validate the input file before touching the toolkit.
    if not os.path.exists(onnx_path):
        print(f"\n✗ 错误: 找不到 ONNX 文件: {onnx_path}")
        return False

    # INT8 quantization needs an existing calibration dataset file.
    if do_quantization:
        if dataset_path is None:
            print("\n✗ 错误: INT8 量化需要提供校准数据集")
            print(" 使用 --dataset 指定数据集文件路径")
            print(" 或运行 --create-dataset 创建示例")
            return False
        if not os.path.exists(dataset_path):
            print(f"\n✗ 错误: 找不到数据集文件: {dataset_path}")
            return False

    # Imported lazily; availability was already verified by check_environment().
    from rknn.api import RKNN

    rknn = RKNN(verbose=verbose)

    try:
        # [1/4] Configure preprocessing and the target platform.
        print("\n[1/4] 配置模型...")
        rknn.config(
            mean_values=[[0, 0, 0]],  # YOLOv8 takes raw 0-255 input
            std_values=[[255, 255, 255]],  # normalized to 0-1 on-device
            target_platform=target_platform
        )
        print(" ✓ 完成")

        # [2/4] Load the ONNX graph.
        print("\n[2/4] 加载 ONNX 模型...")
        ret = rknn.load_onnx(model=onnx_path)
        if ret != 0:
            print(" ✗ 加载失败!")
            return False
        print(" ✓ 完成")

        # [3/4] Build the (optionally INT8-quantized) RKNN model.
        print("\n[3/4] 构建 RKNN 模型...")
        if do_quantization:
            print(f" 使用 INT8 量化,校准数据集: {dataset_path}")
            ret = rknn.build(do_quantization=True, dataset=dataset_path)
        else:
            print(" 使用 FP16 模式(无量化)")
            ret = rknn.build(do_quantization=False)

        if ret != 0:
            print(" ✗ 构建失败!")
            return False
        print(" ✓ 完成")

        # [4/4] Serialize the model to disk.
        print("\n[4/4] 导出 RKNN 模型...")
        ret = rknn.export_rknn(output_path)
        if ret != 0:
            print(" ✗ 导出失败!")
            return False
        print(" ✓ 完成")

    finally:
        # Always free toolkit resources — on success and on every early return.
        rknn.release()

    # Confirm the artifact exists and report deployment next steps.
    if os.path.exists(output_path):
        size_mb = os.path.getsize(output_path) / (1024 * 1024)
        print("\n" + "="*70)
        print(f"✓ 转换成功!")
        print(f" 输出文件: {output_path}")
        print(f" 文件大小: {size_mb:.2f} MB")
        print("="*70)

        print("\n下一步:")
        print(f" 1. 复制到 RK3588:")
        print(f" scp {output_path} orangepi@<rk3588_ip>:/home/orangepi/apps/OrangePi3588Media/models/")
        print(f" 2. 更新配置文件中的模型路径")
        return True
    else:
        print("\n✗ 错误: 输出文件未生成")
        return False
|
||||
|
||||
|
||||
def main():
    """CLI entry point for the ONNX → RKNN converter.

    Returns:
        Process exit code: 0 on success, 1 on failure.
    """
    parser = argparse.ArgumentParser(
        description="将 YOLOv8 ONNX 模型转换为 RKNN",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
# FP16 模式(推荐)
python 04_convert_rknn.py best.onnx -o shoe_detector.rknn

# 指定目标平台
python 04_convert_rknn.py best.onnx -t rk3568

# INT8 量化
python 04_convert_rknn.py best.onnx -q -d dataset.txt

# 创建示例校准数据集
python 04_convert_rknn.py --create-dataset
"""
    )

    # The positional ONNX path is optional so --create-dataset can run alone.
    parser.add_argument("onnx", nargs="?", help="ONNX 模型文件路径")
    parser.add_argument("-o", "--output", help="输出 RKNN 文件路径")
    parser.add_argument("-t", "--target", default="rk3588",
                        help="目标平台 (默认: rk3588)")
    parser.add_argument("-q", "--quantize", action="store_true",
                        help="启用 INT8 量化")
    parser.add_argument("-d", "--dataset", help="量化校准数据集路径 (txt 文件)")
    parser.add_argument("--create-dataset", action="store_true",
                        help="创建示例校准数据集并退出")
    # NOTE(review): default=True makes -v a no-op — verbose is always on.
    parser.add_argument("-v", "--verbose", action="store_true", default=True,
                        help="显示详细信息")

    args = parser.parse_args()

    # Template-generation mode: write dataset.txt and exit immediately.
    if args.create_dataset:
        create_sample_dataset("best.onnx", "dataset.txt")
        return 0

    # An ONNX path is required for conversion.
    if args.onnx is None:
        parser.print_help()
        print("\n错误: 请提供 ONNX 文件路径")
        return 1

    # Fail fast when rknn-toolkit2 is not importable.
    if not check_environment():
        return 1

    success = convert_onnx_to_rknn(
        onnx_path=args.onnx,
        output_path=args.output,
        target_platform=args.target,
        do_quantization=args.quantize,
        dataset_path=args.dataset,
        verbose=args.verbose
    )

    return 0 if success else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
119
05_prepare_ppe_shoe_subset.py
Normal file
119
05_prepare_ppe_shoe_subset.py
Normal file
@ -0,0 +1,119 @@
|
||||
#!/usr/bin/env python3
|
||||
"""从 Construction-PPE 中提取单类 shoe 子集。"""
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
SHOE_CLASSES = {"3", "10"} # boots, no_boots
|
||||
|
||||
|
||||
def ensure_clean_dir(path: Path):
    """Recreate *path* as an empty directory, discarding any previous contents."""
    try:
        shutil.rmtree(path)
    except FileNotFoundError:
        pass  # nothing to remove — fresh create below
    path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def write_yaml(output_dir: Path):
    """Write the single-class fine-tuning data.yaml for the PPE shoe subset."""
    abs_output = output_dir.resolve().as_posix()
    content = (
        "# PPE 鞋子单类微调数据集配置\n"
        "\n"
        f"path: {abs_output}\n"
        "train: images/train\n"
        "val: images/val\n"
        "test: images/test\n"
        "\n"
        "nc: 1\n"
        "names: ['shoe']\n"
        "\n"
        "dataset_info:\n"
        " name: Construction-PPE shoe subset\n"
        " source: Construction-PPE\n"
        " note: 从 boots 和 no_boots 提取并统一映射为 shoe\n"
    )
    (output_dir / "data.yaml").write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def convert_split(source_dir: Path, output_dir: Path, split: str):
    """Copy shoe-only images/labels for one split, remapping classes to 0.

    An image is kept only when its label file contains at least one box whose
    class id is in SHOE_CLASSES; those boxes are rewritten to class 0.

    Args:
        source_dir: Construction-PPE dataset root (images/<split>, labels/<split>).
        output_dir: Destination dataset root; split dirs are created as needed.
        split: Split name ("train", "val" or "test").

    Returns:
        (kept_images, kept_boxes) for this split.
    """
    image_src = source_dir / "images" / split
    label_src = source_dir / "labels" / split
    image_dst = output_dir / "images" / split
    label_dst = output_dir / "labels" / split

    image_dst.mkdir(parents=True, exist_ok=True)
    label_dst.mkdir(parents=True, exist_ok=True)

    kept_images = 0
    kept_boxes = 0

    for label_file in sorted(label_src.glob("*.txt")):
        lines = label_file.read_text(encoding="utf-8").splitlines()
        shoe_lines = []

        for line in lines:
            parts = line.strip().split()
            # Keep only well-formed rows whose class id is a shoe class.
            if len(parts) < 5 or parts[0] not in SHOE_CLASSES:
                continue
            parts[0] = "0"  # single-class dataset: everything becomes "shoe"
            shoe_lines.append(" ".join(parts))

        if not shoe_lines:
            continue

        # Accept any common image extension (consistent with copy_split in
        # 07_build_public_shoe_dataset.py), not just .jpg/.png.
        image_file = None
        for ext in (".jpg", ".jpeg", ".png", ".bmp", ".webp"):
            candidate = image_src / f"{label_file.stem}{ext}"
            if candidate.exists():
                image_file = candidate
                break
        if image_file is None:
            continue

        shutil.copy2(image_file, image_dst / image_file.name)
        (label_dst / label_file.name).write_text("\n".join(shoe_lines) + "\n", encoding="utf-8")
        kept_images += 1
        kept_boxes += len(shoe_lines)

    return kept_images, kept_boxes
|
||||
|
||||
|
||||
def main():
    """Extract the shoe-only subset from a Construction-PPE dataset.

    Walks the train/val/test splits, keeps only shoe-class boxes remapped to
    class 0, writes the new dataset plus data.yaml, and prints statistics.
    """
    parser = argparse.ArgumentParser(description="提取 PPE 鞋子单类子集")
    parser.add_argument(
        "--source",
        default="datasets/construction-ppe",
        help="Construction-PPE 数据集目录",
    )
    parser.add_argument(
        "--output",
        default="datasets/ppe-shoes",
        help="输出目录",
    )
    args = parser.parse_args()

    source_dir = Path(args.source)
    output_dir = Path(args.output)

    # Rebuild the output from scratch on every run.
    ensure_clean_dir(output_dir)

    total_images = 0
    total_boxes = 0
    for split in ("train", "val", "test"):
        kept_images, kept_boxes = convert_split(source_dir, output_dir, split)
        total_images += kept_images
        total_boxes += kept_boxes
        print(f"[{split}] images={kept_images} boxes={kept_boxes}")

    write_yaml(output_dir)
    print(f"\n输出目录: {output_dir}")
    print(f"总图片数: {total_images}")
    print(f"总框数: {total_boxes}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
62
06_finetune_ppe.bat
Normal file
62
06_finetune_ppe.bat
Normal file
@ -0,0 +1,62 @@
|
||||
@echo off
:: Stage-2 fine-tuning: continue training the best checkpoint on the
:: PPE shoe subset produced by 05_prepare_ppe_shoe_subset.py.
setlocal

:: Prepend the Python 3.11 install to PATH so `yolo` resolves to that env.
set "PATH=C:\Users\Tellme\AppData\Local\Programs\Python\Python311\Scripts;C:\Users\Tellme\AppData\Local\Programs\Python\Python311;%PATH%"

echo ============================================================
echo Stage 2 Fine-tuning on PPE shoe subset
echo ============================================================
echo.

set "DATASET=C:\Users\Tellme\apps\ppe-model-training\datasets\ppe-shoes\data.yaml"
set "BASE_MODEL=C:\Users\Tellme\apps\ultralytics\runs\detect\train3\weights\best.pt"

:: The shoe subset must have been generated first.
if not exist "%DATASET%" (
    echo [ERROR] PPE shoe subset not found: %DATASET%
    echo Run:
    echo C:\Users\Tellme\AppData\Local\Programs\Python\Python311\python.exe C:\Users\Tellme\apps\ppe-model-training\05_prepare_ppe_shoe_subset.py
    pause
    exit /b 1
)

if not exist "%BASE_MODEL%" (
    echo [ERROR] Base model not found: %BASE_MODEL%
    pause
    exit /b 1
)

echo [INFO] Dataset: %DATASET%
echo [INFO] Base model: %BASE_MODEL%
echo.

:: Shorter schedule than stage 1 - training starts from a trained checkpoint.
set "EPOCHS=50"
set "IMGSZ=640"
set "BATCH=16"

echo Fine-tune params:
echo Epochs: %EPOCHS%
echo Image Size: %IMGSZ%x%IMGSZ%
echo Batch Size: %BATCH%
echo Device: GPU (cuda:0)
echo.

echo ============================================================
echo Start fine-tuning
echo ============================================================
echo.

yolo detect train data="%DATASET%" model="%BASE_MODEL%" epochs=%EPOCHS% imgsz=%IMGSZ% batch=%BATCH% device=0

if %ERRORLEVEL% neq 0 (
    echo.
    echo [ERROR] Fine-tuning failed.
    pause
    exit /b 1
)

echo.
echo ============================================================
echo Fine-tuning complete
echo ============================================================
echo.
pause
|
||||
145
07_build_public_shoe_dataset.py
Normal file
145
07_build_public_shoe_dataset.py
Normal file
@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Build a merged single-class public shoe dataset for YOLO training."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
DEFAULT_SOURCES = [
|
||||
"datasets/openimages-shoes-yolo",
|
||||
"datasets/ppe-shoes",
|
||||
]
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the dataset merge tool.

    Returns:
        Namespace with `sources` (list of dataset dirs), `output` (merged
        dataset dir) and `clean` (wipe the output before rebuilding).
    """
    parser = argparse.ArgumentParser(description="Merge public shoe datasets into one YOLO dataset")
    parser.add_argument(
        "--sources",
        nargs="+",
        default=DEFAULT_SOURCES,
        help="Source dataset directories containing images/<split> and labels/<split>",
    )
    parser.add_argument(
        "--output",
        default="datasets/shoe-public-mix",
        help="Output merged dataset directory",
    )
    parser.add_argument(
        "--clean",
        action="store_true",
        help="Delete the output directory before rebuilding",
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
def ensure_output_layout(output_dir: Path) -> None:
    """Create the images/ and labels/ directory for every split (idempotent)."""
    for kind in ("images", "labels"):
        for split in ("train", "val", "test"):
            (output_dir / kind / split).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def copy_split(source_dir: Path, output_dir: Path, split: str) -> tuple[int, int]:
|
||||
image_dir = source_dir / "images" / split
|
||||
label_dir = source_dir / "labels" / split
|
||||
if not image_dir.exists() or not label_dir.exists():
|
||||
return 0, 0
|
||||
|
||||
images_copied = 0
|
||||
boxes_copied = 0
|
||||
prefix = source_dir.name.replace("-", "_")
|
||||
|
||||
for label_file in sorted(label_dir.glob("*.txt")):
|
||||
image_file = None
|
||||
for ext in (".jpg", ".jpeg", ".png", ".bmp", ".webp"):
|
||||
candidate = image_dir / f"{label_file.stem}{ext}"
|
||||
if candidate.exists():
|
||||
image_file = candidate
|
||||
break
|
||||
|
||||
if image_file is None:
|
||||
continue
|
||||
|
||||
lines = [line.strip() for line in label_file.read_text(encoding="utf-8").splitlines() if line.strip()]
|
||||
if not lines:
|
||||
continue
|
||||
|
||||
out_stem = f"{prefix}_{label_file.stem}"
|
||||
dst_image = output_dir / "images" / split / f"{out_stem}{image_file.suffix.lower()}"
|
||||
dst_label = output_dir / "labels" / split / f"{out_stem}.txt"
|
||||
|
||||
shutil.copy2(image_file, dst_image)
|
||||
dst_label.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||
|
||||
images_copied += 1
|
||||
boxes_copied += len(lines)
|
||||
|
||||
return images_copied, boxes_copied
|
||||
|
||||
|
||||
def write_yaml(output_dir: Path) -> None:
    """Write the Ultralytics ``data.yaml`` describing the merged dataset.

    The file uses an absolute POSIX-style ``path`` so training can be launched
    from any working directory, and declares the single ``shoe`` class.
    """
    dataset_root = output_dir.resolve().as_posix()
    content = (
        "# Public shoe training mix\n"
        "\n"
        f"path: {dataset_root}\n"
        "train: images/train\n"
        "val: images/val\n"
        "test: images/test\n"
        "\n"
        "nc: 1\n"
        "names: ['shoe']\n"
        "\n"
        "dataset_info:\n"
        "  name: shoe-public-mix\n"
        "  task: detect_shoe\n"
        "  note: merged Open Images shoe data and PPE shoe subset\n"
    )
    (output_dir / "data.yaml").write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def main() -> None:
    """Merge every configured source dataset into one YOLO layout and report counts."""
    args = parse_args()
    output_dir = Path(args.output)

    # Optionally rebuild from a clean slate.
    if args.clean and output_dir.exists():
        shutil.rmtree(output_dir)
    ensure_output_layout(output_dir)

    # Copy each split of each source; remember (images, boxes) per pair.
    summary: dict[str, dict[str, tuple[int, int]]] = defaultdict(dict)
    for source in args.sources:
        source_dir = Path(source)
        if not source_dir.exists():
            raise FileNotFoundError(f"Source dataset not found: {source_dir}")
        for split in ("train", "val", "test"):
            summary[source_dir.name][split] = copy_split(source_dir, output_dir, split)

    write_yaml(output_dir)

    # Per-source/per-split report plus grand totals.
    print(f"Output dataset: {output_dir.resolve()}")
    total_images = 0
    total_boxes = 0
    for source_name, split_map in summary.items():
        print(f"[{source_name}]")
        for split in ("train", "val", "test"):
            img_count, box_count = split_map.get(split, (0, 0))
            total_images += img_count
            total_boxes += box_count
            print(f"  {split}: images={img_count} boxes={box_count}")

    print(f"Total images: {total_images}")
    print(f"Total boxes: {total_boxes}")


if __name__ == "__main__":
    main()
|
||||
155
08_train_compare_models.py
Normal file
155
08_train_compare_models.py
Normal file
@ -0,0 +1,155 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Train multiple YOLO models on the merged shoe dataset and compare ROI performance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent
|
||||
PYDEPS = REPO_ROOT / ".pydeps"
|
||||
if PYDEPS.exists():
|
||||
sys.path.insert(0, str(PYDEPS))
|
||||
|
||||
from ultralytics import YOLO # noqa: E402
|
||||
|
||||
|
||||
DEFAULT_MODELS = ["yolov8s.pt", "yolo11s.pt", "yolo26s.pt"]
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Train and compare YOLO shoe detectors")
|
||||
parser.add_argument("--data", default="datasets/shoe-public-mix/data.yaml", help="Training dataset yaml")
|
||||
parser.add_argument("--roi-dir", default="datasets/roi-shoes", help="Real-world ROI evaluation directory")
|
||||
parser.add_argument("--project", default="runs/shoe_compare", help="Ultralytics project directory")
|
||||
parser.add_argument("--models", nargs="+", default=DEFAULT_MODELS, help="Model checkpoints or aliases")
|
||||
parser.add_argument("--epochs", type=int, default=40, help="Training epochs for each model")
|
||||
parser.add_argument("--imgsz", type=int, default=640, help="Training image size")
|
||||
parser.add_argument("--batch", type=int, default=16, help="Training batch size")
|
||||
parser.add_argument("--workers", type=int, default=8, help="Data loader workers")
|
||||
parser.add_argument("--device", default="0", help="Training device, e.g. 0 or cpu")
|
||||
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
||||
parser.add_argument("--patience", type=int, default=20, help="Early stopping patience")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def safe_name(model_name: str) -> str:
    """Turn a checkpoint path or alias into a filesystem-friendly run-name stem.

    Drops the final suffix and replaces remaining dots and dashes with
    underscores (e.g. ``"yolov8s.pt"`` -> ``"yolov8s"``).
    """
    cleaned = Path(model_name).stem
    for ch in (".", "-"):
        cleaned = cleaned.replace(ch, "_")
    return cleaned
|
||||
|
||||
|
||||
def evaluate_roi(model_path: Path, roi_dir: Path, save_dir: Path, device: str) -> dict:
    """Run a trained detector over the real-world ROI images and summarize hits.

    Annotated prediction images are saved under *save_dir*. An image counts as
    a "hit" when at least one box survives the conf=0.1 threshold.

    Returns:
        Dict with ``roi_total``, ``roi_hits``, ``roi_hit_rate``,
        ``roi_mean_max_conf`` (mean of per-image best confidences, hits only)
        and a ``per_image`` breakdown.
    """
    detector = YOLO(str(model_path))
    predictions = detector.predict(
        source=str(roi_dir),
        conf=0.1,
        save=True,
        project=str(save_dir.parent),
        name=save_dir.name,
        exist_ok=True,
        verbose=False,
        device=device,
    )

    per_image = []
    max_confs = []
    hit_count = 0
    for pred in predictions:
        box_count = len(pred.boxes) if pred.boxes is not None else 0
        # Best-confidence box for this image (0.0 when nothing was detected).
        best_conf = float(pred.boxes.conf.max().item()) if box_count else 0.0
        if box_count:
            hit_count += 1
            max_confs.append(best_conf)
        per_image.append(
            {
                "image": Path(pred.path).name,
                "detections": box_count,
                "max_conf": best_conf,
            }
        )

    total = len(per_image)
    return {
        "roi_total": total,
        "roi_hits": hit_count,
        "roi_hit_rate": hit_count / total if total else 0.0,
        "roi_mean_max_conf": mean(max_confs) if max_confs else 0.0,
        "per_image": per_image,
    }
|
||||
|
||||
|
||||
def main() -> None:
    """Train each requested model on the merged shoe dataset, then score it on ROI images.

    For every checkpoint in ``--models`` this: trains via Ultralytics, runs the
    resulting ``best.pt`` over the real-world ROI directory, and persists a
    per-model ``roi_eval_summary.json`` next to the run plus a cumulative
    ``compare_summary.json`` under the project directory.
    """
    args = parse_args()
    # Keep Ultralytics' config/cache inside the repo instead of the user home.
    os.environ.setdefault("YOLO_CONFIG_DIR", str(REPO_ROOT / ".ultralytics"))

    data = Path(args.data)
    roi_dir = Path(args.roi_dir)
    project = Path(args.project)
    project.mkdir(parents=True, exist_ok=True)

    # Fail fast before any expensive training starts.
    if not data.exists():
        raise FileNotFoundError(f"Dataset yaml not found: {data}")
    if not roi_dir.exists():
        raise FileNotFoundError(f"ROI directory not found: {roi_dir}")

    compare_summary: list[dict] = []
    for model_name in args.models:
        # Run name encodes both the checkpoint and the input size, e.g. yolov8s_shoe_640.
        run_name = f"{safe_name(model_name)}_shoe_{args.imgsz}"
        print("=" * 80)
        print(f"Training {model_name} -> {run_name}")
        print("=" * 80)

        model = YOLO(model_name)
        train_results = model.train(
            data=str(data.resolve()),
            epochs=args.epochs,
            imgsz=args.imgsz,
            batch=args.batch,
            workers=args.workers,
            device=args.device,
            seed=args.seed,
            patience=args.patience,
            project=str(project.resolve()),
            name=run_name,
            exist_ok=True,
            pretrained=True,
            cos_lr=True,   # cosine learning-rate schedule
            amp=True,      # mixed-precision training
            verbose=True,
        )

        # Evaluate this run's best checkpoint on the real-world ROI crops.
        best_path = Path(train_results.save_dir) / "weights" / "best.pt"
        roi_save_dir = Path(train_results.save_dir) / "roi_eval"
        roi_summary = evaluate_roi(best_path, roi_dir, roi_save_dir, args.device)

        record = {
            "model": model_name,
            "run_name": run_name,
            "save_dir": str(Path(train_results.save_dir).resolve()),
            "best_pt": str(best_path.resolve()),
            "metrics": {
                # Ultralytics validation metrics; missing keys default to 0.0.
                "map50": float(train_results.results_dict.get("metrics/mAP50(B)", 0.0)),
                "map50_95": float(train_results.results_dict.get("metrics/mAP50-95(B)", 0.0)),
                "precision": float(train_results.results_dict.get("metrics/precision(B)", 0.0)),
                "recall": float(train_results.results_dict.get("metrics/recall(B)", 0.0)),
            },
            "roi_eval": roi_summary,
        }
        compare_summary.append(record)

        # Per-model summary saved next to its training artifacts.
        summary_path = Path(train_results.save_dir) / "roi_eval_summary.json"
        summary_path.write_text(json.dumps(record, indent=2, ensure_ascii=False), encoding="utf-8")

        # NOTE(review): rewritten after every model so partial comparison results
        # survive an interrupted run — confirm this placement matches the intent.
        compare_json = project / "compare_summary.json"
        compare_json.write_text(json.dumps(compare_summary, indent=2, ensure_ascii=False), encoding="utf-8")

        print(json.dumps(record, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()
|
||||
19
10_run_shoe_compare.ps1
Normal file
19
10_run_shoe_compare.ps1
Normal file
@ -0,0 +1,19 @@
|
||||
# Build the merged public shoe dataset, then train/compare all models at 640x640.
$ErrorActionPreference = "Stop"

# Repo-local dependency folder and Ultralytics config dir (no global installs).
$repo = "C:\Users\tianj\Documents\apps\DetectionModelTraining"
$env:PYTHONPATH = "$repo\.pydeps"
$env:YOLO_CONFIG_DIR = "$repo\.ultralytics"

Set-Location $repo

# Rebuild the merged dataset from scratch before training.
py -3.11 .\07_build_public_shoe_dataset.py --clean

# Train and ROI-compare the three checkpoints at 640 input size.
py -3.11 .\08_train_compare_models.py `
    --data "$repo\datasets\shoe-public-mix\data.yaml" `
    --roi-dir "$repo\datasets\roi-shoes" `
    --project "$repo\runs\shoe_compare" `
    --models "yolov8s.pt" "yolo11s.pt" "yolo26s.pt" `
    --epochs 40 `
    --imgsz 640 `
    --batch 16 `
    --device 0
||||
17
11_run_shoe_compare_960.ps1
Normal file
17
11_run_shoe_compare_960.ps1
Normal file
@ -0,0 +1,17 @@
|
||||
# Train/compare all models at 960x960 (reuses the already-built merged dataset;
# batch is halved to 8 to fit the larger input size in GPU memory).
$ErrorActionPreference = "Stop"

# Repo-local dependency folder and Ultralytics config dir (no global installs).
$repo = "C:\Users\tianj\Documents\apps\DetectionModelTraining"
$env:PYTHONPATH = "$repo\.pydeps"
$env:YOLO_CONFIG_DIR = "$repo\.ultralytics"

Set-Location $repo

# Train and ROI-compare the three checkpoints at 960 input size.
py -3.11 .\08_train_compare_models.py `
    --data "$repo\datasets\shoe-public-mix\data.yaml" `
    --roi-dir "$repo\datasets\roi-shoes" `
    --project "$repo\runs\shoe_compare_960" `
    --models "yolov8s.pt" "yolo11s.pt" "yolo26s.pt" `
    --epochs 40 `
    --imgsz 960 `
    --batch 8 `
    --device 0
|
||||
156
README.md
Normal file
156
README.md
Normal file
@ -0,0 +1,156 @@
|
||||
# 鞋子检测模型训练指南
|
||||
|
||||
## 方案:640x640 单模型(部署时用2窗口)
|
||||
|
||||
**训练阶段**:
|
||||
- 输入:640x640 完整图片
|
||||
- 模型:YOLOv8s
|
||||
- 输出:640x640 模型文件
|
||||
|
||||
**部署阶段**(pipeline配置):
|
||||
- 原图 1920x1080
|
||||
- 分成 2 个 960x1080 窗口
|
||||
- 每个窗口 resize 到 640x640 送入模型
|
||||
- 合并检测结果
|
||||
|
||||
---
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
train/
├── README.md                       # 本文件
├── 01_download_dataset.py          # 下载鞋子数据集(推荐 Open Images)
├── 02_train.bat                    # Windows 一键训练脚本
├── 03_export_onnx.bat              # 导出 ONNX 脚本
├── 04_convert_rknn.py              # 转换为 RKNN 脚本
├── 05_prepare_ppe_shoe_subset.py   # 提取 PPE 鞋子单类子集
├── 06_finetune_ppe.bat             # 用 PPE 鞋子子集做二阶段微调
├── 07_build_public_shoe_dataset.py # 合并公开鞋子数据集为单类训练集
├── 08_train_compare_models.py      # 多模型训练与 ROI 对比评估
├── 10_run_shoe_compare.ps1         # 640x640 对比训练一键脚本
├── 11_run_shoe_compare_960.ps1     # 960x960 对比训练一键脚本
├── data.yaml.template              # 数据集配置文件
└── samples/                        # 示例图片
    ├── calibration/
    ├── test_images/
    └── README.md
```
|
||||
|
||||
---
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 下载数据集
|
||||
|
||||
```bash
|
||||
cd train
|
||||
python 01_download_dataset.py --source openimages --max-samples 5000
|
||||
```
|
||||
|
||||
### 2. 准备配置

下载脚本会自动生成 `datasets/openimages-shoes-yolo/data.yaml`,无需手动编辑。
|
||||
|
||||
### 3. 训练(640x640)
|
||||
|
||||
```bash
|
||||
02_train.bat
|
||||
```
|
||||
|
||||
或手动:
|
||||
```bash
|
||||
yolo detect train \
|
||||
data=datasets/openimages-shoes-yolo/data.yaml \
|
||||
model=yolov8s.pt \
|
||||
epochs=150 \
|
||||
imgsz=640 \
|
||||
batch=16 \
|
||||
device=0
|
||||
```
|
||||
|
||||
**训练参数**:
|
||||
- 模型:YOLOv8s(速度和精度平衡)
|
||||
- 输入:640x640
|
||||
- 预计时间:30-60分钟
|
||||
|
||||
### 4. 导出 ONNX
|
||||
|
||||
```bash
|
||||
03_export_onnx.bat
|
||||
```
|
||||
|
||||
### 5. 转换为 RKNN
|
||||
|
||||
在 Ubuntu PC 上:
|
||||
```bash
|
||||
python 04_convert_rknn.py runs/detect/train/weights/best.onnx -o shoe_detector_640.rknn -t rk3588
|
||||
```
|
||||
|
||||
### 6. 部署(2窗口配置)
|
||||
|
||||
复制到 RK3588:
|
||||
```bash
|
||||
scp shoe_detector_640.rknn orangepi@<rk3588_ip>:/home/orangepi/apps/OrangePi3588Media/models/
|
||||
```
|
||||
|
||||
Pipeline 配置(部署阶段用2窗口):
|
||||
```json
|
||||
{
|
||||
"id": "pre_shoe",
|
||||
"type": "preprocess",
|
||||
"windows": [
|
||||
{"x": 0, "y": 0, "w": 960, "h": 1080},
|
||||
{"x": 960, "y": 0, "w": 960, "h": 1080}
|
||||
],
|
||||
"dst_w": 640,
|
||||
"dst_h": 640
|
||||
}
|
||||
```
|
||||
|
||||
### 7. 方案 A:PPE 二阶段微调
|
||||
|
||||
当 Open Images 基础模型训练完成后,可继续用 PPE 鞋子子集做场景微调:
|
||||
|
||||
```bash
|
||||
python 05_prepare_ppe_shoe_subset.py
|
||||
06_finetune_ppe.bat
|
||||
```
|
||||
|
||||
PPE 鞋子子集来源:
|
||||
- `boots`
|
||||
- `no_boots`
|
||||
|
||||
这两个类会统一映射成单类:
|
||||
- `shoe`
|
||||
|
||||
---
|
||||
|
||||
## 类别说明(Open Images)
|
||||
|
||||
Open Images 官方鞋类层级中,`Footwear` 的子类包括:
|
||||
- `Boot`
|
||||
- `Sandal`
|
||||
- `High heels`
|
||||
- `Roller skates`
|
||||
|
||||
本项目推荐下载:
|
||||
- `Footwear`
|
||||
- `Boot`
|
||||
|
||||
可选补充:
|
||||
- `Sandal`
|
||||
|
||||
不建议默认加入:
|
||||
- `High heels`
|
||||
- `Roller skates`
|
||||
|
||||
训练时统一映射为单一类别:
|
||||
- `0: shoe`
|
||||
|
||||
这样模型目标更聚焦,先尽量把鞋子稳定检出,再在后处理里判断是否为黑色鞋。
|
||||
|
||||
---
|
||||
|
||||
## 相关链接
|
||||
|
||||
- [Open Images 数据集](https://storage.googleapis.com/openimages/web/index.html)
|
||||
- [Ultralytics YOLOv8](https://docs.ultralytics.com/)
|
||||
20
data.yaml.template
Normal file
20
data.yaml.template
Normal file
@ -0,0 +1,20 @@
|
||||
# 单类鞋子检测数据集配置
|
||||
|
||||
path: .
|
||||
train: images/train
|
||||
val: images/val
|
||||
test: images/test
|
||||
|
||||
# 单类 shoe
|
||||
nc: 1
|
||||
names: ['shoe']
|
||||
|
||||
# 推荐来源:
|
||||
# python 01_download_dataset.py --source openimages --max-samples 5000
|
||||
#
|
||||
# 推荐 Open Images 子类:
|
||||
# - Footwear
|
||||
# - Boot
|
||||
#
|
||||
# 可选补充:
|
||||
# - Sandal
|
||||
46
samples/README.md
Normal file
46
samples/README.md
Normal file
@ -0,0 +1,46 @@
|
||||
# 示例图片目录
|
||||
|
||||
用于存放测试图片和量化校准样本。
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
samples/
|
||||
├── test_images/ # 用于测试模型的示例图片
|
||||
├── calibration/ # INT8 量化校准用的图片(约 20-100 张)
|
||||
└── README.md # 本文件
|
||||
```
|
||||
|
||||
## 使用说明
|
||||
|
||||
### 测试图片 (test_images/)
|
||||
|
||||
存放一些典型场景的鞋子图片,用于验证模型效果。
|
||||
|
||||
### 校准图片 (calibration/)
|
||||
|
||||
INT8 量化时需要,用于确定量化参数。
|
||||
|
||||
**要求:**
|
||||
- 应与实际部署场景相似
|
||||
- 包含各种光照、角度、背景的样本
|
||||
- 建议 20-100 张
|
||||
- 图片格式: JPG, PNG, BMP
|
||||
|
||||
**创建校准数据集文件:**
|
||||
|
||||
```bash
|
||||
# Linux/macOS
|
||||
ls samples/calibration/*.jpg > dataset.txt
|
||||
ls samples/calibration/*.png >> dataset.txt
|
||||
|
||||
# Windows CMD
|
||||
dir /b samples\calibration\*.jpg > dataset.txt
|
||||
dir /b samples\calibration\*.png >> dataset.txt
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 校准图片越多,转换时间越长,但精度可能更高
|
||||
- 建议使用训练集的部分图片作为校准集
|
||||
- 不要和测试集重复,避免过拟合
|
||||
Loading…
Reference in New Issue
Block a user