增加训练模型需要的目录和脚本
This commit is contained in:
parent
9658b29fca
commit
9de41aed90
@ -105,14 +105,14 @@
|
||||
"model_w": 768,
|
||||
"model_h": 768,
|
||||
"num_classes": 11,
|
||||
"conf": 0.2,
|
||||
"conf": 0.1,
|
||||
"nms": 0.45,
|
||||
"debug": {
|
||||
"stats": true,
|
||||
"stats_interval": 30,
|
||||
"detections": true
|
||||
},
|
||||
"class_filter": [3, 6]
|
||||
"class_filter": [3, 6, 10]
|
||||
},
|
||||
{
|
||||
"id": "tracker",
|
||||
|
||||
156
train/01_download_dataset.py
Normal file
156
train/01_download_dataset.py
Normal file
@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
下载 Roboflow Safety Shoes Detection 数据集
|
||||
使用方法:
|
||||
python 01_download_dataset.py --api-key YOUR_API_KEY
|
||||
|
||||
或者手动下载:
|
||||
1. 访问 https://universe.roboflow.com/nedrick-chandra-gpg1l/safety-shoes-detection-5qgkg
|
||||
2. 点击 Download → 选择 YOLOv8 格式
|
||||
3. 解压到 datasets/ 目录
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def download_with_roboflow(api_key: str, dataset_dir: str = "datasets"):
|
||||
"""使用 Roboflow API 下载数据集"""
|
||||
try:
|
||||
from roboflow import Roboflow
|
||||
except ImportError:
|
||||
print("错误: 未安装 roboflow 包")
|
||||
print("请运行: pip install roboflow")
|
||||
sys.exit(1)
|
||||
|
||||
print("="*60)
|
||||
print("正在下载 Safety Shoes Detection 数据集...")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
rf = Roboflow(api_key=api_key)
|
||||
project = rf.workspace("nedrick-chandra-gpg1l").project("safety-shoes-detection-5qgkg")
|
||||
dataset = project.version(2).download("yolov8", location=dataset_dir)
|
||||
|
||||
print(f"\n✓ 数据集下载完成: {dataset.location}")
|
||||
return dataset.location
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ 下载失败: {e}")
|
||||
print("\n请尝试手动下载:")
|
||||
print("1. 访问 https://universe.roboflow.com/nedrick-chandra-gpg1l/safety-shoes-detection-5qgkg")
|
||||
print("2. 点击 'Download' → 选择 'YOLOv8' 格式")
|
||||
print("3. 解压到 datasets/ 目录")
|
||||
return None
|
||||
|
||||
|
||||
def modify_yaml_for_single_class(dataset_path: str):
|
||||
"""修改为单类检测配置"""
|
||||
yaml_path = os.path.join(dataset_path, "data.yaml")
|
||||
|
||||
if not os.path.exists(yaml_path):
|
||||
print(f"警告: 找不到 {yaml_path}")
|
||||
return False
|
||||
|
||||
with open(yaml_path, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
# 创建新的单类配置
|
||||
new_content = """# 单类鞋子检测数据集配置
|
||||
# 原数据集: Safety Shoes Detection (Roboflow)
|
||||
# 修改: 合并 safety-shoes 和 no-safety-shoes 为单一的 shoe 类别
|
||||
|
||||
train: ../train/images
|
||||
val: ../valid/images
|
||||
test: ../test/images
|
||||
|
||||
nc: 1
|
||||
names: ['shoe']
|
||||
|
||||
# Roboflow 元信息
|
||||
roboflow:
|
||||
workspace: nedrick-chandra-gpg1l
|
||||
project: safety-shoes-detection-5qgkg
|
||||
version: 2
|
||||
license: CC BY 4.0
|
||||
url: https://universe.roboflow.com/nedrick-chandra-gpg1l/safety-shoes-detection-5qgkg/dataset/2
|
||||
"""
|
||||
|
||||
# 备份原文件
|
||||
backup_path = yaml_path + ".backup"
|
||||
with open(backup_path, 'w') as f:
|
||||
f.write(content)
|
||||
|
||||
# 写入新配置
|
||||
with open(yaml_path, 'w') as f:
|
||||
f.write(new_content)
|
||||
|
||||
print(f"✓ 已修改为单类检测: {yaml_path}")
|
||||
print(f" 原配置备份: {backup_path}")
|
||||
return True
|
||||
|
||||
|
||||
def check_dataset_structure(dataset_path: str):
|
||||
"""检查数据集结构是否正确"""
|
||||
required_dirs = ['train/images', 'train/labels', 'valid/images', 'valid/labels']
|
||||
|
||||
print("\n检查数据集结构...")
|
||||
for dir_name in required_dirs:
|
||||
full_path = os.path.join(dataset_path, dir_name)
|
||||
if os.path.exists(full_path):
|
||||
count = len(os.listdir(full_path))
|
||||
print(f" ✓ {dir_name}: {count} 个文件")
|
||||
else:
|
||||
print(f" ✗ {dir_name}: 不存在")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="下载 Safety Shoes Detection 数据集")
|
||||
parser.add_argument("--api-key", help="Roboflow API Key")
|
||||
parser.add_argument("--dir", default="datasets/safety-shoes-detection",
|
||||
help="数据集保存目录")
|
||||
parser.add_argument("--no-modify", action="store_true",
|
||||
help="不修改 data.yaml(保持原始类别)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 如果提供了 API key,使用 API 下载
|
||||
if args.api_key:
|
||||
dataset_path = download_with_roboflow(args.api_key, args.dir)
|
||||
if dataset_path is None:
|
||||
sys.exit(1)
|
||||
else:
|
||||
# 检查是否已手动下载
|
||||
dataset_path = args.dir
|
||||
if not os.path.exists(dataset_path):
|
||||
print(f"错误: 找不到数据集目录 {dataset_path}")
|
||||
print("\n请使用以下方式之一获取数据集:")
|
||||
print("1. 使用 API 下载: python 01_download_dataset.py --api-key YOUR_KEY")
|
||||
print("2. 手动下载并解压到: datasets/safety-shoes-detection/")
|
||||
sys.exit(1)
|
||||
|
||||
# 检查数据集结构
|
||||
if not check_dataset_structure(dataset_path):
|
||||
print("\n✗ 数据集结构不正确")
|
||||
sys.exit(1)
|
||||
|
||||
# 修改为单类检测
|
||||
if not args.no_modify:
|
||||
modify_yaml_for_single_class(dataset_path)
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("数据集准备完成!")
|
||||
print("="*60)
|
||||
print(f"数据集路径: {dataset_path}")
|
||||
print(f"配置文件: {dataset_path}/data.yaml")
|
||||
print("\n下一步:")
|
||||
print(f" yolo detect train data={dataset_path}/data.yaml model=yolov8n.pt epochs=150 imgsz=640")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
93
train/02_train.bat
Normal file
93
train/02_train.bat
Normal file
@ -0,0 +1,93 @@
|
||||
@echo off
|
||||
chcp 65001 >nul
|
||||
cls
|
||||
|
||||
echo ============================================================
|
||||
echo 训练鞋子检测模型 (YOLOv8)
|
||||
echo ============================================================
|
||||
echo.
|
||||
|
||||
:: 设置数据集路径
|
||||
set DATASET=datasets/safety-shoes-detection/data.yaml
|
||||
|
||||
:: 检查数据集是否存在
|
||||
if not exist %DATASET% (
|
||||
echo [错误] 找不到数据集配置文件: %DATASET%
|
||||
echo.
|
||||
echo 请先下载数据集:
|
||||
echo 1. 访问 https://universe.roboflow.com/nedrick-chandra-gpg1l/safety-shoes-detection-5qgkg
|
||||
echo 2. 点击 Download -^> YOLOv8 格式
|
||||
echo 3. 解压到 datasets/safety-shoes-detection/
|
||||
echo 4. 运行 python 01_download_dataset.py --no-modify
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo [信息] 数据集: %DATASET%
|
||||
echo.
|
||||
|
||||
:: 选择模型
|
||||
echo 选择模型:
|
||||
echo 1. YOLOv8n (轻量级, 速度快, 推荐)
|
||||
echo 2. YOLOv8s (精度更高, 稍慢)
|
||||
echo 3. YOLOv8m (高精度, 较慢)
|
||||
echo.
|
||||
set /p MODEL_CHOICE="输入选择 (1-3, 默认 1): "
|
||||
|
||||
if "%MODEL_CHOICE%"=="" set MODEL_CHOICE=1
|
||||
if "%MODEL_CHOICE%"=="1" (
|
||||
set MODEL=yolov8n.pt
|
||||
set DESC=YOLOv8n (轻量级)
|
||||
)
|
||||
if "%MODEL_CHOICE%"=="2" (
|
||||
set MODEL=yolov8s.pt
|
||||
set DESC=YOLOv8s (标准)
|
||||
)
|
||||
if "%MODEL_CHOICE%"=="3" (
|
||||
set MODEL=yolov8m.pt
|
||||
set DESC=YOLOv8m (高精度)
|
||||
)
|
||||
|
||||
echo.
|
||||
echo [信息] 使用模型: %DESC%
|
||||
echo.
|
||||
|
||||
:: 设置训练参数
|
||||
set EPOCHS=150
|
||||
set IMGSZ=640
|
||||
set BATCH=16
|
||||
|
||||
echo 训练参数:
|
||||
echo - Epochs: %EPOCHS%
|
||||
echo - Image Size: %IMGSZ%
|
||||
echo - Batch Size: %BATCH%
|
||||
echo - Device: GPU (cuda:0)
|
||||
echo.
|
||||
|
||||
echo ============================================================
|
||||
echo 开始训练
|
||||
echo ============================================================
|
||||
echo.
|
||||
|
||||
:: 开始训练
|
||||
yolo detect train data=%DATASET% model=%MODEL% epochs=%EPOCHS% imgsz=%IMGSZ% batch=%BATCH% device=0
|
||||
|
||||
if %ERRORLEVEL% neq 0 (
|
||||
echo.
|
||||
echo [错误] 训练失败!
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo.
|
||||
echo ============================================================
|
||||
echo 训练完成!
|
||||
echo ============================================================
|
||||
echo.
|
||||
echo 模型保存在: runs/detect/train/weights/
|
||||
echo - best.pt (最佳模型)
|
||||
echo - last.pt (最后模型)
|
||||
echo.
|
||||
echo 下一步: 运行 03_export_onnx.bat 导出 ONNX 格式
|
||||
echo.
|
||||
pause
|
||||
75
train/03_export_onnx.bat
Normal file
75
train/03_export_onnx.bat
Normal file
@ -0,0 +1,75 @@
|
||||
@echo off
|
||||
chcp 65001 >nul
|
||||
cls
|
||||
|
||||
echo ============================================================
|
||||
echo 导出 ONNX 模型 (YOLOv8)
|
||||
echo ============================================================
|
||||
echo.
|
||||
|
||||
:: 设置模型路径
|
||||
set MODEL_PATH=runs/detect/train/weights/best.pt
|
||||
|
||||
:: 检查模型是否存在
|
||||
if not exist %MODEL_PATH% (
|
||||
echo [错误] 找不到模型文件: %MODEL_PATH%
|
||||
echo.
|
||||
echo 请先训练模型:
|
||||
echo 运行 02_train.bat
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo [信息] 输入模型: %MODEL_PATH%
|
||||
echo.
|
||||
|
||||
:: 导出 ONNX
|
||||
echo ============================================================
|
||||
echo 导出 ONNX
|
||||
echo ============================================================
|
||||
echo.
|
||||
|
||||
yolo export model=%MODEL_PATH% format=onnx imgsz=640 opset=12 simplify=True
|
||||
|
||||
if %ERRORLEVEL% neq 0 (
|
||||
echo.
|
||||
echo [错误] 导出失败!
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo.
|
||||
echo ============================================================
|
||||
echo 导出完成!
|
||||
echo ============================================================
|
||||
echo.
|
||||
|
||||
:: 检查输出文件
|
||||
set ONNX_PATH=runs\detect\train\weights\best.onnx
|
||||
if exist %ONNX_PATH% (
|
||||
echo [成功] ONNX 模型: %ONNX_PATH%
|
||||
|
||||
:: 获取文件大小
|
||||
for %%I in (%ONNX_PATH%) do (
|
||||
set SIZE=%%~zI
|
||||
)
|
||||
echo [信息] 文件大小: %SIZE% bytes
|
||||
) else (
|
||||
echo [警告] 找不到输出文件
|
||||
)
|
||||
|
||||
echo.
|
||||
echo ============================================================
|
||||
echo 下一步操作
|
||||
echo ============================================================
|
||||
echo.
|
||||
echo 1. 复制 ONNX 文件到 Ubuntu 机器:
|
||||
echo scp %ONNX_PATH% user@ubuntu-pc:~/rknn_convert/
|
||||
echo.
|
||||
echo 2. 在 Ubuntu 上转换为 RKNN:
|
||||
echo python 04_convert_rknn.py best.onnx -o shoe_detector.rknn -t rk3588
|
||||
echo.
|
||||
echo 3. 部署到 RK3588:
|
||||
echo scp shoe_detector.rknn orangepi@^<rk3588_ip^>:/home/orangepi/apps/OrangePi3588Media/models/
|
||||
echo.
|
||||
pause
|
||||
262
train/04_convert_rknn.py
Normal file
262
train/04_convert_rknn.py
Normal file
@ -0,0 +1,262 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
将 YOLOv8 ONNX 模型转换为 RKNN 格式
|
||||
适用于 RK3588 / RK3568 / RK3576 等平台
|
||||
|
||||
环境要求:
|
||||
- Ubuntu x86_64 / Docker
|
||||
- Python 3.8 / 3.9 / 3.10 / 3.11
|
||||
- rknn-toolkit2 (pip install rknn-toolkit2==2.2.0)
|
||||
|
||||
使用方法:
|
||||
# FP16 模式(推荐,速度快精度高)
|
||||
python 04_convert_rknn.py best.onnx -o shoe_detector.rknn -t rk3588
|
||||
|
||||
# INT8 量化(模型更小,需要校准数据集)
|
||||
python 04_convert_rknn.py best.onnx -o shoe_detector.rknn -t rk3588 -q -d dataset.txt
|
||||
|
||||
支持的 target_platform:
|
||||
- rk3588 / rk3588s
|
||||
- rk3568 / rk3566
|
||||
- rk3576
|
||||
- rv1106 / rv1103 / rv1103b
|
||||
- rv1126
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def check_environment():
|
||||
"""检查运行环境"""
|
||||
try:
|
||||
from rknn.api import RKNN
|
||||
print("✓ RKNN Toolkit2 已安装")
|
||||
return True
|
||||
except ImportError:
|
||||
print("✗ 错误: 未安装 RKNN Toolkit2")
|
||||
print("\n请安装:")
|
||||
print(" pip install rknn-toolkit2==2.2.0")
|
||||
print("\n或从源码安装:")
|
||||
print(" https://github.com/airockchip/rknn-toolkit2")
|
||||
return False
|
||||
|
||||
|
||||
def create_sample_dataset(onnx_path: str, output_path: str = "dataset.txt", num_samples: int = 20):
|
||||
"""
|
||||
创建示例量化校准数据集
|
||||
用于 INT8 量化时提供校准图片路径
|
||||
"""
|
||||
print(f"\n创建示例校准数据集: {output_path}")
|
||||
print("注意: 请用实际图片替换这些示例路径")
|
||||
|
||||
sample_content = f"""# RKNN INT8 量化校准数据集
|
||||
# 每行一个图片路径,建议使用 20-100 张典型场景图片
|
||||
# 图片格式: JPG, PNG, BMP 等
|
||||
|
||||
# 示例路径(请替换为实际路径):
|
||||
# /path/to/train/images/img001.jpg
|
||||
# /path/to/train/images/img002.jpg
|
||||
# /path/to/valid/images/img001.jpg
|
||||
|
||||
# 提示:
|
||||
# 1. 图片应与实际部署场景相似
|
||||
# 2. 包含各种光照、角度、背景的样本
|
||||
# 3. 建议 20-100 张,越多越慢但可能更准
|
||||
"""
|
||||
|
||||
with open(output_path, 'w') as f:
|
||||
f.write(sample_content)
|
||||
|
||||
print(f"✓ 示例数据集已创建: {output_path}")
|
||||
print(" 请编辑此文件,添加实际的图片路径")
|
||||
return output_path
|
||||
|
||||
|
||||
def convert_onnx_to_rknn(
|
||||
onnx_path: str,
|
||||
output_path: str = None,
|
||||
target_platform: str = "rk3588",
|
||||
do_quantization: bool = False,
|
||||
dataset_path: str = None,
|
||||
verbose: bool = True
|
||||
):
|
||||
"""
|
||||
转换 ONNX 模型到 RKNN
|
||||
|
||||
Args:
|
||||
onnx_path: ONNX 模型文件路径
|
||||
output_path: 输出 RKNN 文件路径,默认与 ONNX 同名
|
||||
target_platform: 目标平台,默认 rk3588
|
||||
do_quantization: 是否启用 INT8 量化
|
||||
dataset_path: 量化校准数据集路径(txt 文件,每行一张图片路径)
|
||||
verbose: 是否打印详细信息
|
||||
"""
|
||||
if output_path is None:
|
||||
output_path = onnx_path.replace(".onnx", ".rknn")
|
||||
|
||||
# 确保输出目录存在
|
||||
output_dir = os.path.dirname(output_path)
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
print("="*70)
|
||||
print(f"ONNX 转 RKNN")
|
||||
print("="*70)
|
||||
print(f"输入: {onnx_path}")
|
||||
print(f"输出: {output_path}")
|
||||
print(f"目标: {target_platform}")
|
||||
print(f"量化: {'INT8' if do_quantization else 'FP16 (无量化)'}")
|
||||
if do_quantization:
|
||||
print(f"校准: {dataset_path}")
|
||||
print("="*70)
|
||||
|
||||
# 检查输入文件
|
||||
if not os.path.exists(onnx_path):
|
||||
print(f"\n✗ 错误: 找不到 ONNX 文件: {onnx_path}")
|
||||
return False
|
||||
|
||||
# 检查数据集(如果需要量化)
|
||||
if do_quantization:
|
||||
if dataset_path is None:
|
||||
print("\n✗ 错误: INT8 量化需要提供校准数据集")
|
||||
print(" 使用 --dataset 指定数据集文件路径")
|
||||
print(" 或运行 --create-dataset 创建示例")
|
||||
return False
|
||||
if not os.path.exists(dataset_path):
|
||||
print(f"\n✗ 错误: 找不到数据集文件: {dataset_path}")
|
||||
return False
|
||||
|
||||
from rknn.api import RKNN
|
||||
|
||||
# 创建 RKNN 对象
|
||||
rknn = RKNN(verbose=verbose)
|
||||
|
||||
try:
|
||||
# 配置模型
|
||||
print("\n[1/4] 配置模型...")
|
||||
rknn.config(
|
||||
mean_values=[[0, 0, 0]], # YOLOv8 使用 0-255 输入
|
||||
std_values=[[255, 255, 255]], # 归一化到 0-1
|
||||
target_platform=target_platform
|
||||
)
|
||||
print(" ✓ 完成")
|
||||
|
||||
# 加载 ONNX
|
||||
print("\n[2/4] 加载 ONNX 模型...")
|
||||
ret = rknn.load_onnx(model=onnx_path)
|
||||
if ret != 0:
|
||||
print(" ✗ 加载失败!")
|
||||
return False
|
||||
print(" ✓ 完成")
|
||||
|
||||
# 构建模型
|
||||
print("\n[3/4] 构建 RKNN 模型...")
|
||||
if do_quantization:
|
||||
print(f" 使用 INT8 量化,校准数据集: {dataset_path}")
|
||||
ret = rknn.build(do_quantization=True, dataset=dataset_path)
|
||||
else:
|
||||
print(" 使用 FP16 模式(无量化)")
|
||||
ret = rknn.build(do_quantization=False)
|
||||
|
||||
if ret != 0:
|
||||
print(" ✗ 构建失败!")
|
||||
return False
|
||||
print(" ✓ 完成")
|
||||
|
||||
# 导出 RKNN
|
||||
print("\n[4/4] 导出 RKNN 模型...")
|
||||
ret = rknn.export_rknn(output_path)
|
||||
if ret != 0:
|
||||
print(" ✗ 导出失败!")
|
||||
return False
|
||||
print(" ✓ 完成")
|
||||
|
||||
finally:
|
||||
rknn.release()
|
||||
|
||||
# 验证输出
|
||||
if os.path.exists(output_path):
|
||||
size_mb = os.path.getsize(output_path) / (1024 * 1024)
|
||||
print("\n" + "="*70)
|
||||
print(f"✓ 转换成功!")
|
||||
print(f" 输出文件: {output_path}")
|
||||
print(f" 文件大小: {size_mb:.2f} MB")
|
||||
print("="*70)
|
||||
|
||||
print("\n下一步:")
|
||||
print(f" 1. 复制到 RK3588:")
|
||||
print(f" scp {output_path} orangepi@<rk3588_ip>:/home/orangepi/apps/OrangePi3588Media/models/")
|
||||
print(f" 2. 更新配置文件中的模型路径")
|
||||
return True
|
||||
else:
|
||||
print("\n✗ 错误: 输出文件未生成")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="将 YOLOv8 ONNX 模型转换为 RKNN",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
示例:
|
||||
# FP16 模式(推荐)
|
||||
python 04_convert_rknn.py best.onnx -o shoe_detector.rknn
|
||||
|
||||
# 指定目标平台
|
||||
python 04_convert_rknn.py best.onnx -t rk3568
|
||||
|
||||
# INT8 量化
|
||||
python 04_convert_rknn.py best.onnx -q -d dataset.txt
|
||||
|
||||
# 创建示例校准数据集
|
||||
python 04_convert_rknn.py --create-dataset
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument("onnx", nargs="?", help="ONNX 模型文件路径")
|
||||
parser.add_argument("-o", "--output", help="输出 RKNN 文件路径")
|
||||
parser.add_argument("-t", "--target", default="rk3588",
|
||||
help="目标平台 (默认: rk3588)")
|
||||
parser.add_argument("-q", "--quantize", action="store_true",
|
||||
help="启用 INT8 量化")
|
||||
parser.add_argument("-d", "--dataset", help="量化校准数据集路径 (txt 文件)")
|
||||
parser.add_argument("--create-dataset", action="store_true",
|
||||
help="创建示例校准数据集并退出")
|
||||
parser.add_argument("-v", "--verbose", action="store_true", default=True,
|
||||
help="显示详细信息")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建示例数据集
|
||||
if args.create_dataset:
|
||||
create_sample_dataset("best.onnx", "dataset.txt")
|
||||
return 0
|
||||
|
||||
# 检查参数
|
||||
if args.onnx is None:
|
||||
parser.print_help()
|
||||
print("\n错误: 请提供 ONNX 文件路径")
|
||||
return 1
|
||||
|
||||
# 检查环境
|
||||
if not check_environment():
|
||||
return 1
|
||||
|
||||
# 执行转换
|
||||
success = convert_onnx_to_rknn(
|
||||
onnx_path=args.onnx,
|
||||
output_path=args.output,
|
||||
target_platform=args.target,
|
||||
do_quantization=args.quantize,
|
||||
dataset_path=args.dataset,
|
||||
verbose=args.verbose
|
||||
)
|
||||
|
||||
return 0 if success else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
157
train/README.md
Normal file
157
train/README.md
Normal file
@ -0,0 +1,157 @@
|
||||
# 鞋子检测模型训练指南
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
train/
|
||||
├── README.md # 本文件
|
||||
├── 01_download_dataset.py # 下载数据集脚本
|
||||
├── 02_train.bat # Windows 训练脚本
|
||||
├── 03_export_onnx.bat # 导出 ONNX 脚本
|
||||
├── 04_convert_rknn.py # 转换为 RKNN 脚本
|
||||
├── data.yaml.template # 数据集配置文件模板
|
||||
└── samples/ # 示例图片(用于测试)
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 环境准备(Windows + GPU)
|
||||
|
||||
```bash
|
||||
# 安装 PyTorch (CUDA 11.8)
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||
|
||||
# 安装 ultralytics
|
||||
pip install ultralytics
|
||||
```
|
||||
|
||||
### 2. 下载数据集
|
||||
|
||||
**手动下载(推荐):**
|
||||
1. 访问:https://universe.roboflow.com/nedrick-chandra-gpg1l/safety-shoes-detection-5qgkg
|
||||
2. 点击 **"Download"** → 选择 **"YOLOv8"** 格式
|
||||
3. 解压到 `datasets/safety-shoes-detection/` 目录
|
||||
|
||||
**或使用脚本(需要 API Key):**
|
||||
```bash
|
||||
python 01_download_dataset.py --api-key YOUR_API_KEY
|
||||
```
|
||||
|
||||
### 3. 准备数据集配置
|
||||
|
||||
复制模板并修改路径:
|
||||
```bash
|
||||
cp data.yaml.template datasets/safety-shoes-detection/data.yaml
|
||||
# 编辑 data.yaml,确保路径正确
|
||||
```
|
||||
|
||||
### 4. 训练模型
|
||||
|
||||
**一键训练:**
|
||||
```bash
|
||||
02_train.bat
|
||||
```
|
||||
|
||||
**或手动训练:**
|
||||
```bash
|
||||
# YOLOv8n - 轻量级,速度快
|
||||
yolo detect train data=datasets/safety-shoes-detection/data.yaml model=yolov8n.pt epochs=150 imgsz=640 batch=16 device=0
|
||||
|
||||
# YOLOv8s - 精度更高(可选)
|
||||
# yolo detect train data=datasets/safety-shoes-detection/data.yaml model=yolov8s.pt epochs=150 imgsz=640 batch=16 device=0
|
||||
```
|
||||
|
||||
训练完成后,模型保存在:`runs/detect/train/weights/best.pt`
|
||||
|
||||
### 5. 导出 ONNX
|
||||
|
||||
```bash
|
||||
03_export_onnx.bat
|
||||
```
|
||||
|
||||
输出:`runs/detect/train/weights/best.onnx`
|
||||
|
||||
### 6. 转换为 RKNN
|
||||
|
||||
**在 Ubuntu PC 上运行:**
|
||||
|
||||
```bash
|
||||
# 安装 RKNN Toolkit2
|
||||
pip install rknn-toolkit2==2.2.0
|
||||
|
||||
# 转换(FP16 模式 - 推荐)
|
||||
python 04_convert_rknn.py runs/detect/train/weights/best.onnx -o shoe_detector.rknn -t rk3588
|
||||
|
||||
# 或 INT8 量化(需要校准数据集)
|
||||
# python 04_convert_rknn.py runs/detect/train/weights/best.onnx -o shoe_detector.rknn -t rk3588 -q -d dataset.txt
|
||||
```
|
||||
|
||||
### 7. 部署到 RK3588
|
||||
|
||||
```bash
|
||||
scp shoe_detector.rknn orangepi@<rk3588_ip>:/home/orangepi/apps/OrangePi3588Media/models/
|
||||
```
|
||||
|
||||
然后在 `configs/full_pipeline_1080p.json` 中更新模型路径。
|
||||
|
||||
---
|
||||
|
||||
## 训练参数说明
|
||||
|
||||
| 参数 | YOLOv8n | YOLOv8s | 说明 |
|
||||
|------|---------|---------|------|
|
||||
| 模型大小 | 3.2MB | 11MB | 文件大小 |
|
||||
| 推理速度 | ~30-40ms | ~50-60ms | RK3588 NPU |
|
||||
| mAP | ~0.75 | ~0.82 | 精度 |
|
||||
| 推荐场景 | 实时检测 | 高精度 | 选择建议 |
|
||||
|
||||
---
|
||||
|
||||
## 数据集说明
|
||||
|
||||
### Safety Shoes Detection
|
||||
- **来源**: Roboflow Universe
|
||||
- **类别**: safety-shoes / no-safety-shoes
|
||||
- **图片数**: 约 1000+ 张
|
||||
- **场景**: 工地安全鞋检测
|
||||
|
||||
### 转换为单类检测
|
||||
|
||||
我们将两类合并为单一的 `shoe` 类别:
|
||||
- 检测所有鞋子(安全鞋、运动鞋、布鞋等)
|
||||
- 后续通过颜色分析判断是否为劳保鞋
|
||||
|
||||
---
|
||||
|
||||
## 常见问题
|
||||
|
||||
### Q1: 训练时显存不足?
|
||||
降低 batch size:
|
||||
```bash
|
||||
yolo detect train ... batch=8 # 默认 16,改为 8
|
||||
```
|
||||
|
||||
### Q2: 如何提高精度?
|
||||
1. 增加训练 epoch:`epochs=200`
|
||||
2. 使用更大模型:`model=yolov8s.pt`
|
||||
3. 增大输入尺寸:`imgsz=768`
|
||||
4. 收集更多现场图片 fine-tune
|
||||
|
||||
### Q3: RKNN 转换失败?
|
||||
1. 确保使用正确的 opset (12)
|
||||
2. 使用 `simplify=True` 导出 ONNX
|
||||
3. 检查 RKNN Toolkit2 版本与板端驱动匹配
|
||||
|
||||
### Q4: 检测不到鞋子?
|
||||
1. 降低置信度阈值:`conf=0.15`
|
||||
2. 检查 class_filter 是否正确设置
|
||||
3. 确认输入图像尺寸与模型匹配
|
||||
|
||||
---
|
||||
|
||||
## 相关链接
|
||||
|
||||
- [Ultralytics YOLOv8 文档](https://docs.ultralytics.com/)
|
||||
- [RKNN Toolkit2 文档](https://github.com/airockchip/rknn-toolkit2)
|
||||
- [Roboflow Universe - Safety Shoes](https://universe.roboflow.com/nedrick-chandra-gpg1l/safety-shoes-detection-5qgkg)
|
||||
38
train/data.yaml.template
Normal file
38
train/data.yaml.template
Normal file
@ -0,0 +1,38 @@
|
||||
# 单类鞋子检测数据集配置
|
||||
# 基于 Roboflow Safety Shoes Detection 数据集修改
|
||||
# 将原有的两类 (safety-shoes / no-safety-shoes) 合并为单一的 shoe 类别
|
||||
|
||||
# 数据集路径
|
||||
train: ../train/images
|
||||
val: ../valid/images
|
||||
test: ../test/images
|
||||
|
||||
# 类别配置
|
||||
nc: 1 # 类别数
|
||||
names: ['shoe'] # 类别名称列表
|
||||
|
||||
# Roboflow 元信息(可选)
|
||||
roboflow:
|
||||
workspace: nedrick-chandra-gpg1l
|
||||
project: safety-shoes-detection-5qgkg
|
||||
version: 2
|
||||
license: CC BY 4.0
|
||||
url: https://universe.roboflow.com/nedrick-chandra-gpg1l/safety-shoes-detection-5qgkg/dataset/2
|
||||
|
||||
# 使用说明:
|
||||
# 1. 将此文件复制到数据集根目录,命名为 data.yaml
|
||||
# 2. 确保 train/val/test 路径正确
|
||||
# 3. 运行训练: yolo detect train data=data.yaml model=yolov8n.pt epochs=150 imgsz=640
|
||||
#
|
||||
# 目录结构应为:
|
||||
# safety-shoes-detection/
|
||||
# ├── data.yaml # 本文件
|
||||
# ├── train/
|
||||
# │ ├── images/ # 训练图片
|
||||
# │ └── labels/ # YOLO 格式标注文件
|
||||
# ├── valid/
|
||||
# │ ├── images/ # 验证图片
|
||||
# │ └── labels/ # YOLO 格式标注文件
|
||||
# └── test/
|
||||
# ├── images/ # 测试图片
|
||||
# └── labels/ # YOLO 格式标注文件
|
||||
46
train/samples/README.md
Normal file
46
train/samples/README.md
Normal file
@ -0,0 +1,46 @@
|
||||
# 示例图片目录
|
||||
|
||||
用于存放测试图片和量化校准样本。
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
samples/
|
||||
├── test_images/ # 用于测试模型的示例图片
|
||||
├── calibration/ # INT8 量化校准用的图片(约 20-100 张)
|
||||
└── README.md # 本文件
|
||||
```
|
||||
|
||||
## 使用说明
|
||||
|
||||
### 测试图片 (test_images/)
|
||||
|
||||
存放一些典型场景的鞋子图片,用于验证模型效果。
|
||||
|
||||
### 校准图片 (calibration/)
|
||||
|
||||
INT8 量化时需要,用于确定量化参数。
|
||||
|
||||
**要求:**
|
||||
- 应与实际部署场景相似
|
||||
- 包含各种光照、角度、背景的样本
|
||||
- 建议 20-100 张
|
||||
- 图片格式: JPG, PNG, BMP
|
||||
|
||||
**创建校准数据集文件:**
|
||||
|
||||
```bash
|
||||
# Linux/macOS
|
||||
ls samples/calibration/*.jpg > dataset.txt
|
||||
ls samples/calibration/*.png >> dataset.txt
|
||||
|
||||
# Windows CMD
|
||||
dir /b samples\calibration\*.jpg > dataset.txt
|
||||
dir /b samples\calibration\*.png >> dataset.txt
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 校准图片越多,转换时间越长,但精度可能更高
|
||||
- 建议使用训练集的部分图片作为校准集
|
||||
- 不要和测试集重复,避免过拟合
|
||||
Loading…
Reference in New Issue
Block a user