项目初始化;已经训练了三种yolo版本的2种尺寸的模型
This commit is contained in:
commit
45f1f3182f
48
.gitignore
vendored
Normal file
48
.gitignore
vendored
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
# Python cache and local tooling
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
.pytest_cache/
|
||||||
|
.mypy_cache/
|
||||||
|
.ruff_cache/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
|
||||||
|
# Virtualenvs and local dependency folders
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
ENV/
|
||||||
|
.pydeps/
|
||||||
|
|
||||||
|
# Project-local tool state
|
||||||
|
.ultralytics/
|
||||||
|
|
||||||
|
# OS / editor noise
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
Desktop.ini
|
||||||
|
*.tmp
|
||||||
|
*.temp
|
||||||
|
|
||||||
|
# Large datasets and generated training assets
|
||||||
|
datasets/
|
||||||
|
runs/
|
||||||
|
|
||||||
|
# Exported / downloaded model artifacts
|
||||||
|
pretrained/
|
||||||
|
*.pt
|
||||||
|
*.pth
|
||||||
|
*.onnx
|
||||||
|
*.rknn
|
||||||
|
*.engine
|
||||||
|
*.bin
|
||||||
|
|
||||||
|
# Logs and reports
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# Keep placeholder files in otherwise ignored trees
|
||||||
|
!datasets/.gitkeep
|
||||||
|
!runs/.gitkeep
|
||||||
|
!pretrained/.gitkeep
|
||||||
411
01_download_dataset.py
Normal file
411
01_download_dataset.py
Normal file
@ -0,0 +1,411 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
下载鞋子检测数据集。
|
||||||
|
|
||||||
|
支持:
|
||||||
|
- Ultralytics Construction-PPE
|
||||||
|
- Open Images V7 (推荐用于单类 shoe 检测)
|
||||||
|
|
||||||
|
Open Images 推荐类别:
|
||||||
|
- Footwear
|
||||||
|
- Boot
|
||||||
|
|
||||||
|
可选补充:
|
||||||
|
- Sandal
|
||||||
|
|
||||||
|
不建议默认加入:
|
||||||
|
- High heels
|
||||||
|
- Roller skates
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import shutil
|
||||||
|
import ssl
|
||||||
|
import sys
|
||||||
|
import urllib.request
|
||||||
|
import zipfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
OPENIMAGES_RECOMMENDED_CLASSES = ["Footwear", "Boot"]
|
||||||
|
OPENIMAGES_OPTIONAL_CLASSES = ["Sandal"]
|
||||||
|
OPENIMAGES_NOT_RECOMMENDED_CLASSES = ["High heels", "Roller skates"]
|
||||||
|
|
||||||
|
|
||||||
|
def download_ultralytics_cppe(dataset_dir: str = "datasets/construction-ppe"):
|
||||||
|
"""下载 Ultralytics Construction-PPE 数据集。"""
|
||||||
|
url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/construction-ppe.zip"
|
||||||
|
zip_path = "construction-ppe.zip"
|
||||||
|
|
||||||
|
print("=" * 70)
|
||||||
|
print("下载 Construction-PPE 数据集")
|
||||||
|
print("=" * 70)
|
||||||
|
print(f"来源: {url}")
|
||||||
|
print(f"目标: {dataset_dir}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
os.makedirs(dataset_dir, exist_ok=True)
|
||||||
|
|
||||||
|
print("[1/3] 下载中... (约 178MB)")
|
||||||
|
try:
|
||||||
|
ssl_context = ssl.create_default_context()
|
||||||
|
ssl_context.check_hostname = False
|
||||||
|
ssl_context.verify_mode = ssl.CERT_NONE
|
||||||
|
|
||||||
|
with urllib.request.urlopen(url, context=ssl_context, timeout=300) as response:
|
||||||
|
total_size = int(response.headers.get("content-length", 0))
|
||||||
|
downloaded = 0
|
||||||
|
chunk_size = 8192
|
||||||
|
|
||||||
|
with open(zip_path, "wb") as f:
|
||||||
|
while True:
|
||||||
|
chunk = response.read(chunk_size)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
f.write(chunk)
|
||||||
|
downloaded += len(chunk)
|
||||||
|
if total_size > 0:
|
||||||
|
percent = (downloaded / total_size) * 100
|
||||||
|
print(
|
||||||
|
f"\r 进度: {percent:.1f}% ({downloaded}/{total_size} bytes)",
|
||||||
|
end="",
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n ✓ 下载完成")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n ✗ 下载失败: {e}")
|
||||||
|
print("\n请手动下载:")
|
||||||
|
print(f" 1. 访问: {url}")
|
||||||
|
print(" 2. 下载 construction-ppe.zip")
|
||||||
|
print(f" 3. 解压到 {dataset_dir}/")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print("\n[2/3] 解压中...")
|
||||||
|
try:
|
||||||
|
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||||
|
zip_ref.extractall(dataset_dir)
|
||||||
|
print(" ✓ 解压完成")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ✗ 解压失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print("\n[3/3] 清理临时文件...")
|
||||||
|
os.remove(zip_path)
|
||||||
|
print(" ✓ 完成")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def create_single_class_yaml(dataset_dir: str, source_name: str, dataset_path_value: str = "."):
|
||||||
|
"""创建单类 shoe 检测配置文件。"""
|
||||||
|
yaml_content = f"""# 单类鞋子检测数据集配置
|
||||||
|
|
||||||
|
path: {dataset_path_value}
|
||||||
|
train: images/train
|
||||||
|
val: images/val
|
||||||
|
test: images/test
|
||||||
|
|
||||||
|
nc: 1
|
||||||
|
names: ['shoe']
|
||||||
|
|
||||||
|
dataset_info:
|
||||||
|
name: {source_name}
|
||||||
|
task: detect_shoe
|
||||||
|
note: 所有鞋类子类统一映射为单一类别 shoe
|
||||||
|
"""
|
||||||
|
|
||||||
|
yaml_path = os.path.join(dataset_dir, "data.yaml")
|
||||||
|
with open(yaml_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(yaml_content)
|
||||||
|
|
||||||
|
print(f"\n✓ 配置文件创建: {yaml_path}")
|
||||||
|
return yaml_path
|
||||||
|
|
||||||
|
|
||||||
|
def rewrite_yaml_for_existing_splits(dataset_dir: str, source_name: str):
|
||||||
|
"""根据现有目录结构重写 data.yaml。"""
|
||||||
|
images_root = Path(dataset_dir) / "images"
|
||||||
|
split_names = [name for name in ("train", "val", "test") if (images_root / name).exists()]
|
||||||
|
|
||||||
|
if not split_names:
|
||||||
|
raise RuntimeError(f"未找到任何图像 split: {images_root}")
|
||||||
|
|
||||||
|
train_split = "train" if "train" in split_names else split_names[0]
|
||||||
|
val_split = "val" if "val" in split_names else train_split
|
||||||
|
test_line = f"test: images/{'test' if 'test' in split_names else val_split}"
|
||||||
|
|
||||||
|
yaml_content = f"""# 单类鞋子检测数据集配置
|
||||||
|
|
||||||
|
path: .
|
||||||
|
train: images/{train_split}
|
||||||
|
val: images/{val_split}
|
||||||
|
{test_line}
|
||||||
|
|
||||||
|
nc: 1
|
||||||
|
names: ['shoe']
|
||||||
|
|
||||||
|
dataset_info:
|
||||||
|
name: {source_name}
|
||||||
|
task: detect_shoe
|
||||||
|
note: 所有鞋类子类统一映射为单一类别 shoe
|
||||||
|
"""
|
||||||
|
|
||||||
|
yaml_path = Path(dataset_dir) / "data.yaml"
|
||||||
|
yaml_path.write_text(yaml_content, encoding="utf-8")
|
||||||
|
print(f"\n✓ 配置文件更新: {yaml_path}")
|
||||||
|
return str(yaml_path)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_openimages_train_val_split(export_dir: str, train_ratio: float = 0.9, seed: int = 42):
|
||||||
|
"""如果导出结果只有单个 split,则自动切分为 train/val。"""
|
||||||
|
images_root = Path(export_dir) / "images"
|
||||||
|
labels_root = Path(export_dir) / "labels"
|
||||||
|
train_images = images_root / "train"
|
||||||
|
val_images = images_root / "val"
|
||||||
|
train_labels = labels_root / "train"
|
||||||
|
val_labels = labels_root / "val"
|
||||||
|
|
||||||
|
if train_images.exists() and train_labels.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
if not val_images.exists() or not val_labels.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
image_files = sorted([p for p in val_images.iterdir() if p.is_file()])
|
||||||
|
if len(image_files) < 2:
|
||||||
|
return
|
||||||
|
|
||||||
|
train_images.mkdir(parents=True, exist_ok=True)
|
||||||
|
train_labels.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
rng = random.Random(seed)
|
||||||
|
rng.shuffle(image_files)
|
||||||
|
|
||||||
|
split_idx = max(1, min(len(image_files) - 1, int(len(image_files) * train_ratio)))
|
||||||
|
train_files = image_files[:split_idx]
|
||||||
|
|
||||||
|
for image_path in train_files:
|
||||||
|
label_path = val_labels / f"{image_path.stem}.txt"
|
||||||
|
shutil.move(str(image_path), train_images / image_path.name)
|
||||||
|
if label_path.exists():
|
||||||
|
shutil.move(str(label_path), train_labels / label_path.name)
|
||||||
|
|
||||||
|
print("\n自动切分单一 split 为 train/val:")
|
||||||
|
print(f" train: {len(list(train_images.iterdir()))} 张")
|
||||||
|
print(f" val: {len(list(val_images.iterdir()))} 张")
|
||||||
|
|
||||||
|
|
||||||
|
def merge_openimages_to_single_class(export_dir: str):
|
||||||
|
"""将 Open Images 导出的多类别 YOLO 标签合并为单类 shoe。"""
|
||||||
|
labels_root = Path(export_dir) / "labels"
|
||||||
|
if not labels_root.exists():
|
||||||
|
print(f"✗ 未找到标签目录: {labels_root}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
total_files = 0
|
||||||
|
total_boxes = 0
|
||||||
|
changed_files = 0
|
||||||
|
|
||||||
|
for label_file in labels_root.rglob("*.txt"):
|
||||||
|
total_files += 1
|
||||||
|
lines = label_file.read_text(encoding="utf-8").splitlines()
|
||||||
|
rewritten = []
|
||||||
|
file_changed = False
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
parts = line.strip().split()
|
||||||
|
if len(parts) < 5:
|
||||||
|
continue
|
||||||
|
parts[0] = "0"
|
||||||
|
rewritten.append(" ".join(parts))
|
||||||
|
total_boxes += 1
|
||||||
|
file_changed = True
|
||||||
|
|
||||||
|
label_file.write_text("\n".join(rewritten) + ("\n" if rewritten else ""), encoding="utf-8")
|
||||||
|
|
||||||
|
if file_changed:
|
||||||
|
changed_files += 1
|
||||||
|
|
||||||
|
ensure_openimages_train_val_split(export_dir)
|
||||||
|
rewrite_yaml_for_existing_splits(export_dir, "Open Images V7")
|
||||||
|
|
||||||
|
print("\n单类合并完成:")
|
||||||
|
print(f" 标签文件: {changed_files}/{total_files}")
|
||||||
|
print(f" 标注框数: {total_boxes}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def download_openimages(classes: list, max_samples: int, dataset_dir: str):
|
||||||
|
"""通过 FiftyOne 下载 Open Images 并导出为单类 shoe 数据集。"""
|
||||||
|
try:
|
||||||
|
import fiftyone as fo
|
||||||
|
import fiftyone.zoo as foz
|
||||||
|
except ImportError:
|
||||||
|
print("错误: 未安装 fiftyone")
|
||||||
|
print("请运行: pip install fiftyone fiftyone-db-ubuntu2204")
|
||||||
|
return False
|
||||||
|
|
||||||
|
export_dir = dataset_dir + "-yolo"
|
||||||
|
|
||||||
|
print("=" * 70)
|
||||||
|
print("下载 Open Images V7 数据集")
|
||||||
|
print("=" * 70)
|
||||||
|
print(f"类别: {classes}")
|
||||||
|
print(f"最大样本数: {max_samples}")
|
||||||
|
print("原始缓存目录: FiftyOne 默认缓存目录")
|
||||||
|
print(f"YOLO 导出目录: {export_dir}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
if os.path.exists(export_dir):
|
||||||
|
print(f"检测到已有导出目录,先删除: {export_dir}")
|
||||||
|
shutil.rmtree(export_dir)
|
||||||
|
|
||||||
|
try:
|
||||||
|
dataset = foz.load_zoo_dataset(
|
||||||
|
"open-images-v7",
|
||||||
|
split="train",
|
||||||
|
label_types=["detections"],
|
||||||
|
classes=classes,
|
||||||
|
max_samples=max_samples,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n导出为 YOLO 格式...")
|
||||||
|
dataset.export(
|
||||||
|
export_dir=export_dir,
|
||||||
|
dataset_type=fo.types.YOLOv5Dataset,
|
||||||
|
label_field="ground_truth",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not merge_openimages_to_single_class(export_dir):
|
||||||
|
return False
|
||||||
|
|
||||||
|
print(f"\n✓ 数据集保存: {export_dir}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ 下载失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def check_dataset(dataset_dir: str):
|
||||||
|
"""检查数据集完整性。"""
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("检查数据集")
|
||||||
|
print("=" * 70)
|
||||||
|
|
||||||
|
required_dirs = [
|
||||||
|
"images/train",
|
||||||
|
"images/val",
|
||||||
|
"images/test",
|
||||||
|
"labels/train",
|
||||||
|
"labels/val",
|
||||||
|
"labels/test",
|
||||||
|
]
|
||||||
|
|
||||||
|
all_ok = True
|
||||||
|
for dir_name in required_dirs:
|
||||||
|
full_path = os.path.join(dataset_dir, dir_name)
|
||||||
|
if os.path.exists(full_path):
|
||||||
|
count = len(
|
||||||
|
[f for f in os.listdir(full_path) if os.path.isfile(os.path.join(full_path, f))]
|
||||||
|
)
|
||||||
|
print(f" ✓ {dir_name}: {count} 个文件")
|
||||||
|
else:
|
||||||
|
print(f" ✗ {dir_name}: 不存在")
|
||||||
|
all_ok = False
|
||||||
|
|
||||||
|
yaml_path = os.path.join(dataset_dir, "data.yaml")
|
||||||
|
if os.path.exists(yaml_path):
|
||||||
|
print(f" ✓ data.yaml: 已生成")
|
||||||
|
else:
|
||||||
|
print(f" ✗ data.yaml: 不存在")
|
||||||
|
all_ok = False
|
||||||
|
|
||||||
|
return all_ok
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="下载鞋子检测数据集",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
epilog="""
|
||||||
|
示例:
|
||||||
|
# 下载 Construction-PPE
|
||||||
|
python 01_download_dataset.py --source ultralytics
|
||||||
|
|
||||||
|
# 下载 Open Images 推荐鞋子类别
|
||||||
|
python 01_download_dataset.py --source openimages --max-samples 5000
|
||||||
|
|
||||||
|
# 如需补充凉鞋样本
|
||||||
|
python 01_download_dataset.py --source openimages --classes Footwear Boot Sandal --max-samples 8000
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--source",
|
||||||
|
choices=["ultralytics", "openimages"],
|
||||||
|
default="openimages",
|
||||||
|
help="数据源 (默认: openimages)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dir",
|
||||||
|
default="datasets/openimages-shoes",
|
||||||
|
help="数据集保存目录",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-samples",
|
||||||
|
type=int,
|
||||||
|
default=5000,
|
||||||
|
help="Open Images 最大样本数 (默认: 5000)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--classes",
|
||||||
|
nargs="+",
|
||||||
|
default=OPENIMAGES_RECOMMENDED_CLASSES,
|
||||||
|
help="Open Images 类别 (默认: Footwear Boot)",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
success = False
|
||||||
|
final_dataset_dir = args.dir
|
||||||
|
|
||||||
|
print("=" * 70)
|
||||||
|
print("鞋子检测数据集准备")
|
||||||
|
print("=" * 70)
|
||||||
|
if args.source == "openimages":
|
||||||
|
print(f"推荐类别: {OPENIMAGES_RECOMMENDED_CLASSES}")
|
||||||
|
print(f"可选补充: {OPENIMAGES_OPTIONAL_CLASSES}")
|
||||||
|
print(f"默认不建议: {OPENIMAGES_NOT_RECOMMENDED_CLASSES}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
if args.source == "ultralytics":
|
||||||
|
success = download_ultralytics_cppe(args.dir)
|
||||||
|
if success:
|
||||||
|
create_single_class_yaml(args.dir, "Construction-PPE", dataset_path_value="construction-ppe")
|
||||||
|
check_dataset(args.dir)
|
||||||
|
|
||||||
|
elif args.source == "openimages":
|
||||||
|
success = download_openimages(args.classes, args.max_samples, args.dir)
|
||||||
|
final_dataset_dir = args.dir + "-yolo"
|
||||||
|
if success:
|
||||||
|
check_dataset(final_dataset_dir)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
print("\n" + "=" * 70)
|
||||||
|
print("数据集准备完成!")
|
||||||
|
print("=" * 70)
|
||||||
|
print(f"训练数据集路径: {final_dataset_dir}")
|
||||||
|
print("\n下一步:")
|
||||||
|
print(f" 1. 检查配置: {final_dataset_dir}/data.yaml")
|
||||||
|
print(f" 2. 开始训练: 02_train.bat")
|
||||||
|
print(
|
||||||
|
f" 3. 或手动: yolo detect train data={final_dataset_dir}/data.yaml model=yolov8s.pt epochs=150 imgsz=640"
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
print("\n✗ 数据集准备失败")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
90
02_train.bat
Normal file
90
02_train.bat
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
@echo off
|
||||||
|
chcp 65001 >nul
|
||||||
|
cls
|
||||||
|
|
||||||
|
:: 设置 Python 3.11 路径
|
||||||
|
set "PATH=C:\Users\Tellme\AppData\Local\Programs\Python\Python311\Scripts;C:\Users\Tellme\AppData\Local\Programs\Python\Python311;%PATH%"
|
||||||
|
|
||||||
|
echo ============================================================
|
||||||
|
echo 训练鞋子检测模型 (YOLOv8 + 640x640)
|
||||||
|
echo ============================================================
|
||||||
|
echo.
|
||||||
|
|
||||||
|
:: 设置数据集路径
|
||||||
|
set DATASET=datasets/openimages-shoes-yolo/data.yaml
|
||||||
|
|
||||||
|
:: 检查数据集是否存在
|
||||||
|
if not exist %DATASET% (
|
||||||
|
echo [错误] 找不到数据集配置文件: %DATASET%
|
||||||
|
echo.
|
||||||
|
echo 请先下载数据集:
|
||||||
|
echo python 01_download_dataset.py --source openimages --max-samples 5000
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
echo [信息] 数据集: %DATASET%
|
||||||
|
echo.
|
||||||
|
|
||||||
|
:: 选择模型
|
||||||
|
echo 选择模型:
|
||||||
|
echo 1. YOLOv8n (轻量级, 速度快)
|
||||||
|
echo 2. YOLOv8s (推荐, 速度和精度平衡)
|
||||||
|
echo 3. YOLOv8m (高精度, 较慢)
|
||||||
|
echo.
|
||||||
|
set /p MODEL_CHOICE="输入选择 (1-3, 默认 2): "
|
||||||
|
|
||||||
|
if "%MODEL_CHOICE%"=="" set MODEL_CHOICE=2
|
||||||
|
if "%MODEL_CHOICE%"=="1" (
|
||||||
|
set MODEL=yolov8n.pt
|
||||||
|
set DESC=YOLOv8n
|
||||||
|
)
|
||||||
|
if "%MODEL_CHOICE%"=="2" (
|
||||||
|
set MODEL=yolov8s.pt
|
||||||
|
set DESC=YOLOv8s (推荐)
|
||||||
|
)
|
||||||
|
if "%MODEL_CHOICE%"=="3" (
|
||||||
|
set MODEL=yolov8m.pt
|
||||||
|
set DESC=YOLOv8m
|
||||||
|
)
|
||||||
|
|
||||||
|
echo.
|
||||||
|
echo [信息] 使用模型: %DESC%
|
||||||
|
echo.
|
||||||
|
|
||||||
|
:: 训练参数
|
||||||
|
set EPOCHS=150
|
||||||
|
set IMGSZ=640
|
||||||
|
set BATCH=16
|
||||||
|
|
||||||
|
echo 训练参数:
|
||||||
|
echo - Epochs: %EPOCHS%
|
||||||
|
echo - Image Size: %IMGSZ%x%IMGSZ%
|
||||||
|
echo - Batch Size: %BATCH%
|
||||||
|
echo - Device: GPU (cuda:0)
|
||||||
|
echo.
|
||||||
|
|
||||||
|
echo ============================================================
|
||||||
|
echo 开始训练
|
||||||
|
echo ============================================================
|
||||||
|
echo.
|
||||||
|
|
||||||
|
yolo detect train data=%DATASET% model=%MODEL% epochs=%EPOCHS% imgsz=%IMGSZ% batch=%BATCH% device=0
|
||||||
|
|
||||||
|
if %ERRORLEVEL% neq 0 (
|
||||||
|
echo.
|
||||||
|
echo [错误] 训练失败!
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
echo.
|
||||||
|
echo ============================================================
|
||||||
|
echo 训练完成!
|
||||||
|
echo ============================================================
|
||||||
|
echo.
|
||||||
|
echo 模型保存在: runs/detect/train/weights/best.pt
|
||||||
|
echo.
|
||||||
|
echo 下一步: 运行 03_export_onnx.bat 导出 ONNX
|
||||||
|
echo.
|
||||||
|
pause
|
||||||
38
03_export_onnx.bat
Normal file
38
03_export_onnx.bat
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
@echo off
|
||||||
|
chcp 65001 >nul
|
||||||
|
cls
|
||||||
|
|
||||||
|
:: 设置 Python 3.11 路径
|
||||||
|
set "PATH=C:\Users\Tellme\AppData\Local\Programs\Python\Python311\Scripts;C:\Users\Tellme\AppData\Local\Programs\Python\Python311;%PATH%"
|
||||||
|
|
||||||
|
echo ============================================================
|
||||||
|
echo 导出 ONNX 模型 (640x640)
|
||||||
|
echo ============================================================
|
||||||
|
echo.
|
||||||
|
|
||||||
|
set MODEL_PATH=%USERPROFILE%\apps\ultralytics\runs\detect\train\weights\best.pt
|
||||||
|
|
||||||
|
if not exist %MODEL_PATH% (
|
||||||
|
echo [错误] 找不到模型: %MODEL_PATH%
|
||||||
|
echo 请先运行 02_train.bat 训练
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
echo [信息] 输入模型: %MODEL_PATH%
|
||||||
|
echo.
|
||||||
|
|
||||||
|
yolo export model=%MODEL_PATH% format=onnx imgsz=640 opset=12 simplify
|
||||||
|
|
||||||
|
if %ERRORLEVEL% neq 0 (
|
||||||
|
echo [错误] 导出失败!
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
echo.
|
||||||
|
echo [成功] ONNX 模型: %USERPROFILE%\apps\ultralytics\runs\detect\train\weights\best.onnx
|
||||||
|
echo.
|
||||||
|
echo 下一步: 在 Ubuntu 上运行 04_convert_rknn.py 转换
|
||||||
|
echo.
|
||||||
|
pause
|
||||||
262
04_convert_rknn.py
Normal file
262
04_convert_rknn.py
Normal file
@ -0,0 +1,262 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
将 YOLOv8 ONNX 模型转换为 RKNN 格式
|
||||||
|
适用于 RK3588 / RK3568 / RK3576 等平台
|
||||||
|
|
||||||
|
环境要求:
|
||||||
|
- Ubuntu x86_64 / Docker
|
||||||
|
- Python 3.8 / 3.9 / 3.10 / 3.11
|
||||||
|
- rknn-toolkit2 (pip install rknn-toolkit2==2.2.0)
|
||||||
|
|
||||||
|
使用方法:
|
||||||
|
# FP16 模式(推荐,速度快精度高)
|
||||||
|
python 04_convert_rknn.py best.onnx -o shoe_detector.rknn -t rk3588
|
||||||
|
|
||||||
|
# INT8 量化(模型更小,需要校准数据集)
|
||||||
|
python 04_convert_rknn.py best.onnx -o shoe_detector.rknn -t rk3588 -q -d dataset.txt
|
||||||
|
|
||||||
|
支持的 target_platform:
|
||||||
|
- rk3588 / rk3588s
|
||||||
|
- rk3568 / rk3566
|
||||||
|
- rk3576
|
||||||
|
- rv1106 / rv1103 / rv1103b
|
||||||
|
- rv1126
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def check_environment():
|
||||||
|
"""检查运行环境"""
|
||||||
|
try:
|
||||||
|
from rknn.api import RKNN
|
||||||
|
print("✓ RKNN Toolkit2 已安装")
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
print("✗ 错误: 未安装 RKNN Toolkit2")
|
||||||
|
print("\n请安装:")
|
||||||
|
print(" pip install rknn-toolkit2==2.2.0")
|
||||||
|
print("\n或从源码安装:")
|
||||||
|
print(" https://github.com/airockchip/rknn-toolkit2")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def create_sample_dataset(onnx_path: str, output_path: str = "dataset.txt", num_samples: int = 20):
|
||||||
|
"""
|
||||||
|
创建示例量化校准数据集
|
||||||
|
用于 INT8 量化时提供校准图片路径
|
||||||
|
"""
|
||||||
|
print(f"\n创建示例校准数据集: {output_path}")
|
||||||
|
print("注意: 请用实际图片替换这些示例路径")
|
||||||
|
|
||||||
|
sample_content = f"""# RKNN INT8 量化校准数据集
|
||||||
|
# 每行一个图片路径,建议使用 20-100 张典型场景图片
|
||||||
|
# 图片格式: JPG, PNG, BMP 等
|
||||||
|
|
||||||
|
# 示例路径(请替换为实际路径):
|
||||||
|
# /path/to/train/images/img001.jpg
|
||||||
|
# /path/to/train/images/img002.jpg
|
||||||
|
# /path/to/valid/images/img001.jpg
|
||||||
|
|
||||||
|
# 提示:
|
||||||
|
# 1. 图片应与实际部署场景相似
|
||||||
|
# 2. 包含各种光照、角度、背景的样本
|
||||||
|
# 3. 建议 20-100 张,越多越慢但可能更准
|
||||||
|
"""
|
||||||
|
|
||||||
|
with open(output_path, 'w') as f:
|
||||||
|
f.write(sample_content)
|
||||||
|
|
||||||
|
print(f"✓ 示例数据集已创建: {output_path}")
|
||||||
|
print(" 请编辑此文件,添加实际的图片路径")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def convert_onnx_to_rknn(
|
||||||
|
onnx_path: str,
|
||||||
|
output_path: str = None,
|
||||||
|
target_platform: str = "rk3588",
|
||||||
|
do_quantization: bool = False,
|
||||||
|
dataset_path: str = None,
|
||||||
|
verbose: bool = True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
转换 ONNX 模型到 RKNN
|
||||||
|
|
||||||
|
Args:
|
||||||
|
onnx_path: ONNX 模型文件路径
|
||||||
|
output_path: 输出 RKNN 文件路径,默认与 ONNX 同名
|
||||||
|
target_platform: 目标平台,默认 rk3588
|
||||||
|
do_quantization: 是否启用 INT8 量化
|
||||||
|
dataset_path: 量化校准数据集路径(txt 文件,每行一张图片路径)
|
||||||
|
verbose: 是否打印详细信息
|
||||||
|
"""
|
||||||
|
if output_path is None:
|
||||||
|
output_path = onnx_path.replace(".onnx", ".rknn")
|
||||||
|
|
||||||
|
# 确保输出目录存在
|
||||||
|
output_dir = os.path.dirname(output_path)
|
||||||
|
if output_dir and not os.path.exists(output_dir):
|
||||||
|
os.makedirs(output_dir)
|
||||||
|
|
||||||
|
print("="*70)
|
||||||
|
print(f"ONNX 转 RKNN")
|
||||||
|
print("="*70)
|
||||||
|
print(f"输入: {onnx_path}")
|
||||||
|
print(f"输出: {output_path}")
|
||||||
|
print(f"目标: {target_platform}")
|
||||||
|
print(f"量化: {'INT8' if do_quantization else 'FP16 (无量化)'}")
|
||||||
|
if do_quantization:
|
||||||
|
print(f"校准: {dataset_path}")
|
||||||
|
print("="*70)
|
||||||
|
|
||||||
|
# 检查输入文件
|
||||||
|
if not os.path.exists(onnx_path):
|
||||||
|
print(f"\n✗ 错误: 找不到 ONNX 文件: {onnx_path}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查数据集(如果需要量化)
|
||||||
|
if do_quantization:
|
||||||
|
if dataset_path is None:
|
||||||
|
print("\n✗ 错误: INT8 量化需要提供校准数据集")
|
||||||
|
print(" 使用 --dataset 指定数据集文件路径")
|
||||||
|
print(" 或运行 --create-dataset 创建示例")
|
||||||
|
return False
|
||||||
|
if not os.path.exists(dataset_path):
|
||||||
|
print(f"\n✗ 错误: 找不到数据集文件: {dataset_path}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
from rknn.api import RKNN
|
||||||
|
|
||||||
|
# 创建 RKNN 对象
|
||||||
|
rknn = RKNN(verbose=verbose)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 配置模型
|
||||||
|
print("\n[1/4] 配置模型...")
|
||||||
|
rknn.config(
|
||||||
|
mean_values=[[0, 0, 0]], # YOLOv8 使用 0-255 输入
|
||||||
|
std_values=[[255, 255, 255]], # 归一化到 0-1
|
||||||
|
target_platform=target_platform
|
||||||
|
)
|
||||||
|
print(" ✓ 完成")
|
||||||
|
|
||||||
|
# 加载 ONNX
|
||||||
|
print("\n[2/4] 加载 ONNX 模型...")
|
||||||
|
ret = rknn.load_onnx(model=onnx_path)
|
||||||
|
if ret != 0:
|
||||||
|
print(" ✗ 加载失败!")
|
||||||
|
return False
|
||||||
|
print(" ✓ 完成")
|
||||||
|
|
||||||
|
# 构建模型
|
||||||
|
print("\n[3/4] 构建 RKNN 模型...")
|
||||||
|
if do_quantization:
|
||||||
|
print(f" 使用 INT8 量化,校准数据集: {dataset_path}")
|
||||||
|
ret = rknn.build(do_quantization=True, dataset=dataset_path)
|
||||||
|
else:
|
||||||
|
print(" 使用 FP16 模式(无量化)")
|
||||||
|
ret = rknn.build(do_quantization=False)
|
||||||
|
|
||||||
|
if ret != 0:
|
||||||
|
print(" ✗ 构建失败!")
|
||||||
|
return False
|
||||||
|
print(" ✓ 完成")
|
||||||
|
|
||||||
|
# 导出 RKNN
|
||||||
|
print("\n[4/4] 导出 RKNN 模型...")
|
||||||
|
ret = rknn.export_rknn(output_path)
|
||||||
|
if ret != 0:
|
||||||
|
print(" ✗ 导出失败!")
|
||||||
|
return False
|
||||||
|
print(" ✓ 完成")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
rknn.release()
|
||||||
|
|
||||||
|
# 验证输出
|
||||||
|
if os.path.exists(output_path):
|
||||||
|
size_mb = os.path.getsize(output_path) / (1024 * 1024)
|
||||||
|
print("\n" + "="*70)
|
||||||
|
print(f"✓ 转换成功!")
|
||||||
|
print(f" 输出文件: {output_path}")
|
||||||
|
print(f" 文件大小: {size_mb:.2f} MB")
|
||||||
|
print("="*70)
|
||||||
|
|
||||||
|
print("\n下一步:")
|
||||||
|
print(f" 1. 复制到 RK3588:")
|
||||||
|
print(f" scp {output_path} orangepi@<rk3588_ip>:/home/orangepi/apps/OrangePi3588Media/models/")
|
||||||
|
print(f" 2. 更新配置文件中的模型路径")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print("\n✗ 错误: 输出文件未生成")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="将 YOLOv8 ONNX 模型转换为 RKNN",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
epilog="""
|
||||||
|
示例:
|
||||||
|
# FP16 模式(推荐)
|
||||||
|
python 04_convert_rknn.py best.onnx -o shoe_detector.rknn
|
||||||
|
|
||||||
|
# 指定目标平台
|
||||||
|
python 04_convert_rknn.py best.onnx -t rk3568
|
||||||
|
|
||||||
|
# INT8 量化
|
||||||
|
python 04_convert_rknn.py best.onnx -q -d dataset.txt
|
||||||
|
|
||||||
|
# 创建示例校准数据集
|
||||||
|
python 04_convert_rknn.py --create-dataset
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("onnx", nargs="?", help="ONNX 模型文件路径")
|
||||||
|
parser.add_argument("-o", "--output", help="输出 RKNN 文件路径")
|
||||||
|
parser.add_argument("-t", "--target", default="rk3588",
|
||||||
|
help="目标平台 (默认: rk3588)")
|
||||||
|
parser.add_argument("-q", "--quantize", action="store_true",
|
||||||
|
help="启用 INT8 量化")
|
||||||
|
parser.add_argument("-d", "--dataset", help="量化校准数据集路径 (txt 文件)")
|
||||||
|
parser.add_argument("--create-dataset", action="store_true",
|
||||||
|
help="创建示例校准数据集并退出")
|
||||||
|
parser.add_argument("-v", "--verbose", action="store_true", default=True,
|
||||||
|
help="显示详细信息")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 创建示例数据集
|
||||||
|
if args.create_dataset:
|
||||||
|
create_sample_dataset("best.onnx", "dataset.txt")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# 检查参数
|
||||||
|
if args.onnx is None:
|
||||||
|
parser.print_help()
|
||||||
|
print("\n错误: 请提供 ONNX 文件路径")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# 检查环境
|
||||||
|
if not check_environment():
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# 执行转换
|
||||||
|
success = convert_onnx_to_rknn(
|
||||||
|
onnx_path=args.onnx,
|
||||||
|
output_path=args.output,
|
||||||
|
target_platform=args.target,
|
||||||
|
do_quantization=args.quantize,
|
||||||
|
dataset_path=args.dataset,
|
||||||
|
verbose=args.verbose
|
||||||
|
)
|
||||||
|
|
||||||
|
return 0 if success else 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
119
05_prepare_ppe_shoe_subset.py
Normal file
119
05_prepare_ppe_shoe_subset.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""从 Construction-PPE 中提取单类 shoe 子集。"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
SHOE_CLASSES = {"3", "10"} # boots, no_boots
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_clean_dir(path: Path):
|
||||||
|
if path.exists():
|
||||||
|
shutil.rmtree(path)
|
||||||
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def write_yaml(output_dir: Path):
|
||||||
|
yaml_path = output_dir / "data.yaml"
|
||||||
|
abs_output = output_dir.resolve().as_posix()
|
||||||
|
yaml_path.write_text(
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
"# PPE 鞋子单类微调数据集配置",
|
||||||
|
"",
|
||||||
|
f"path: {abs_output}",
|
||||||
|
"train: images/train",
|
||||||
|
"val: images/val",
|
||||||
|
"test: images/test",
|
||||||
|
"",
|
||||||
|
"nc: 1",
|
||||||
|
"names: ['shoe']",
|
||||||
|
"",
|
||||||
|
"dataset_info:",
|
||||||
|
" name: Construction-PPE shoe subset",
|
||||||
|
" source: Construction-PPE",
|
||||||
|
" note: 从 boots 和 no_boots 提取并统一映射为 shoe",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_split(source_dir: Path, output_dir: Path, split: str):
|
||||||
|
image_src = source_dir / "images" / split
|
||||||
|
label_src = source_dir / "labels" / split
|
||||||
|
image_dst = output_dir / "images" / split
|
||||||
|
label_dst = output_dir / "labels" / split
|
||||||
|
|
||||||
|
image_dst.mkdir(parents=True, exist_ok=True)
|
||||||
|
label_dst.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
kept_images = 0
|
||||||
|
kept_boxes = 0
|
||||||
|
|
||||||
|
for label_file in sorted(label_src.glob("*.txt")):
|
||||||
|
lines = label_file.read_text(encoding="utf-8").splitlines()
|
||||||
|
shoe_lines = []
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
parts = line.strip().split()
|
||||||
|
if len(parts) < 5 or parts[0] not in SHOE_CLASSES:
|
||||||
|
continue
|
||||||
|
parts[0] = "0"
|
||||||
|
shoe_lines.append(" ".join(parts))
|
||||||
|
|
||||||
|
if not shoe_lines:
|
||||||
|
continue
|
||||||
|
|
||||||
|
image_file = image_src / f"{label_file.stem}.jpg"
|
||||||
|
if not image_file.exists():
|
||||||
|
image_file = image_src / f"{label_file.stem}.png"
|
||||||
|
if not image_file.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
shutil.copy2(image_file, image_dst / image_file.name)
|
||||||
|
(label_dst / label_file.name).write_text("\n".join(shoe_lines) + "\n", encoding="utf-8")
|
||||||
|
kept_images += 1
|
||||||
|
kept_boxes += len(shoe_lines)
|
||||||
|
|
||||||
|
return kept_images, kept_boxes
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="提取 PPE 鞋子单类子集")
|
||||||
|
parser.add_argument(
|
||||||
|
"--source",
|
||||||
|
default="datasets/construction-ppe",
|
||||||
|
help="Construction-PPE 数据集目录",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output",
|
||||||
|
default="datasets/ppe-shoes",
|
||||||
|
help="输出目录",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
source_dir = Path(args.source)
|
||||||
|
output_dir = Path(args.output)
|
||||||
|
|
||||||
|
ensure_clean_dir(output_dir)
|
||||||
|
|
||||||
|
total_images = 0
|
||||||
|
total_boxes = 0
|
||||||
|
for split in ("train", "val", "test"):
|
||||||
|
kept_images, kept_boxes = convert_split(source_dir, output_dir, split)
|
||||||
|
total_images += kept_images
|
||||||
|
total_boxes += kept_boxes
|
||||||
|
print(f"[{split}] images={kept_images} boxes={kept_boxes}")
|
||||||
|
|
||||||
|
write_yaml(output_dir)
|
||||||
|
print(f"\n输出目录: {output_dir}")
|
||||||
|
print(f"总图片数: {total_images}")
|
||||||
|
print(f"总框数: {total_boxes}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
62
06_finetune_ppe.bat
Normal file
62
06_finetune_ppe.bat
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
@echo off
|
||||||
|
setlocal
|
||||||
|
|
||||||
|
set "PATH=C:\Users\Tellme\AppData\Local\Programs\Python\Python311\Scripts;C:\Users\Tellme\AppData\Local\Programs\Python\Python311;%PATH%"
|
||||||
|
|
||||||
|
echo ============================================================
|
||||||
|
echo Stage 2 Fine-tuning on PPE shoe subset
|
||||||
|
echo ============================================================
|
||||||
|
echo.
|
||||||
|
|
||||||
|
set "DATASET=C:\Users\Tellme\apps\ppe-model-training\datasets\ppe-shoes\data.yaml"
|
||||||
|
set "BASE_MODEL=C:\Users\Tellme\apps\ultralytics\runs\detect\train3\weights\best.pt"
|
||||||
|
|
||||||
|
if not exist "%DATASET%" (
|
||||||
|
echo [ERROR] PPE shoe subset not found: %DATASET%
|
||||||
|
echo Run:
|
||||||
|
echo C:\Users\Tellme\AppData\Local\Programs\Python\Python311\python.exe C:\Users\Tellme\apps\ppe-model-training\05_prepare_ppe_shoe_subset.py
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
if not exist "%BASE_MODEL%" (
|
||||||
|
echo [ERROR] Base model not found: %BASE_MODEL%
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
echo [INFO] Dataset: %DATASET%
|
||||||
|
echo [INFO] Base model: %BASE_MODEL%
|
||||||
|
echo.
|
||||||
|
|
||||||
|
set "EPOCHS=50"
|
||||||
|
set "IMGSZ=640"
|
||||||
|
set "BATCH=16"
|
||||||
|
|
||||||
|
echo Fine-tune params:
|
||||||
|
echo Epochs: %EPOCHS%
|
||||||
|
echo Image Size: %IMGSZ%x%IMGSZ%
|
||||||
|
echo Batch Size: %BATCH%
|
||||||
|
echo Device: GPU (cuda:0)
|
||||||
|
echo.
|
||||||
|
|
||||||
|
echo ============================================================
|
||||||
|
echo Start fine-tuning
|
||||||
|
echo ============================================================
|
||||||
|
echo.
|
||||||
|
|
||||||
|
yolo detect train data="%DATASET%" model="%BASE_MODEL%" epochs=%EPOCHS% imgsz=%IMGSZ% batch=%BATCH% device=0
|
||||||
|
|
||||||
|
if %ERRORLEVEL% neq 0 (
|
||||||
|
echo.
|
||||||
|
echo [ERROR] Fine-tuning failed.
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
echo.
|
||||||
|
echo ============================================================
|
||||||
|
echo Fine-tuning complete
|
||||||
|
echo ============================================================
|
||||||
|
echo.
|
||||||
|
pause
|
||||||
145
07_build_public_shoe_dataset.py
Normal file
145
07_build_public_shoe_dataset.py
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Build a merged single-class public shoe dataset for YOLO training."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import shutil
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_SOURCES = [
|
||||||
|
"datasets/openimages-shoes-yolo",
|
||||||
|
"datasets/ppe-shoes",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(description="Merge public shoe datasets into one YOLO dataset")
|
||||||
|
parser.add_argument(
|
||||||
|
"--sources",
|
||||||
|
nargs="+",
|
||||||
|
default=DEFAULT_SOURCES,
|
||||||
|
help="Source dataset directories containing images/<split> and labels/<split>",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output",
|
||||||
|
default="datasets/shoe-public-mix",
|
||||||
|
help="Output merged dataset directory",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--clean",
|
||||||
|
action="store_true",
|
||||||
|
help="Delete the output directory before rebuilding",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_output_layout(output_dir: Path) -> None:
|
||||||
|
for split in ("train", "val", "test"):
|
||||||
|
(output_dir / "images" / split).mkdir(parents=True, exist_ok=True)
|
||||||
|
(output_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def copy_split(source_dir: Path, output_dir: Path, split: str) -> tuple[int, int]:
|
||||||
|
image_dir = source_dir / "images" / split
|
||||||
|
label_dir = source_dir / "labels" / split
|
||||||
|
if not image_dir.exists() or not label_dir.exists():
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
|
images_copied = 0
|
||||||
|
boxes_copied = 0
|
||||||
|
prefix = source_dir.name.replace("-", "_")
|
||||||
|
|
||||||
|
for label_file in sorted(label_dir.glob("*.txt")):
|
||||||
|
image_file = None
|
||||||
|
for ext in (".jpg", ".jpeg", ".png", ".bmp", ".webp"):
|
||||||
|
candidate = image_dir / f"{label_file.stem}{ext}"
|
||||||
|
if candidate.exists():
|
||||||
|
image_file = candidate
|
||||||
|
break
|
||||||
|
|
||||||
|
if image_file is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
lines = [line.strip() for line in label_file.read_text(encoding="utf-8").splitlines() if line.strip()]
|
||||||
|
if not lines:
|
||||||
|
continue
|
||||||
|
|
||||||
|
out_stem = f"{prefix}_{label_file.stem}"
|
||||||
|
dst_image = output_dir / "images" / split / f"{out_stem}{image_file.suffix.lower()}"
|
||||||
|
dst_label = output_dir / "labels" / split / f"{out_stem}.txt"
|
||||||
|
|
||||||
|
shutil.copy2(image_file, dst_image)
|
||||||
|
dst_label.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||||
|
|
||||||
|
images_copied += 1
|
||||||
|
boxes_copied += len(lines)
|
||||||
|
|
||||||
|
return images_copied, boxes_copied
|
||||||
|
|
||||||
|
|
||||||
|
def write_yaml(output_dir: Path) -> None:
|
||||||
|
yaml_path = output_dir / "data.yaml"
|
||||||
|
yaml_path.write_text(
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
"# Public shoe training mix",
|
||||||
|
"",
|
||||||
|
f"path: {output_dir.resolve().as_posix()}",
|
||||||
|
"train: images/train",
|
||||||
|
"val: images/val",
|
||||||
|
"test: images/test",
|
||||||
|
"",
|
||||||
|
"nc: 1",
|
||||||
|
"names: ['shoe']",
|
||||||
|
"",
|
||||||
|
"dataset_info:",
|
||||||
|
" name: shoe-public-mix",
|
||||||
|
" task: detect_shoe",
|
||||||
|
" note: merged Open Images shoe data and PPE shoe subset",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = parse_args()
|
||||||
|
output_dir = Path(args.output)
|
||||||
|
|
||||||
|
if args.clean and output_dir.exists():
|
||||||
|
shutil.rmtree(output_dir)
|
||||||
|
|
||||||
|
ensure_output_layout(output_dir)
|
||||||
|
|
||||||
|
summary: dict[str, dict[str, tuple[int, int]]] = defaultdict(dict)
|
||||||
|
for source in args.sources:
|
||||||
|
source_dir = Path(source)
|
||||||
|
if not source_dir.exists():
|
||||||
|
raise FileNotFoundError(f"Source dataset not found: {source_dir}")
|
||||||
|
|
||||||
|
for split in ("train", "val", "test"):
|
||||||
|
summary[source_dir.name][split] = copy_split(source_dir, output_dir, split)
|
||||||
|
|
||||||
|
write_yaml(output_dir)
|
||||||
|
|
||||||
|
print(f"Output dataset: {output_dir.resolve()}")
|
||||||
|
total_images = 0
|
||||||
|
total_boxes = 0
|
||||||
|
for source_name, split_map in summary.items():
|
||||||
|
print(f"[{source_name}]")
|
||||||
|
for split in ("train", "val", "test"):
|
||||||
|
images_copied, boxes_copied = split_map.get(split, (0, 0))
|
||||||
|
total_images += images_copied
|
||||||
|
total_boxes += boxes_copied
|
||||||
|
print(f" {split}: images={images_copied} boxes={boxes_copied}")
|
||||||
|
|
||||||
|
print(f"Total images: {total_images}")
|
||||||
|
print(f"Total boxes: {total_boxes}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
155
08_train_compare_models.py
Normal file
155
08_train_compare_models.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Train multiple YOLO models on the merged shoe dataset and compare ROI performance."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from statistics import mean
|
||||||
|
|
||||||
|
|
||||||
|
REPO_ROOT = Path(__file__).resolve().parent
|
||||||
|
PYDEPS = REPO_ROOT / ".pydeps"
|
||||||
|
if PYDEPS.exists():
|
||||||
|
sys.path.insert(0, str(PYDEPS))
|
||||||
|
|
||||||
|
from ultralytics import YOLO # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_MODELS = ["yolov8s.pt", "yolo11s.pt", "yolo26s.pt"]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(description="Train and compare YOLO shoe detectors")
|
||||||
|
parser.add_argument("--data", default="datasets/shoe-public-mix/data.yaml", help="Training dataset yaml")
|
||||||
|
parser.add_argument("--roi-dir", default="datasets/roi-shoes", help="Real-world ROI evaluation directory")
|
||||||
|
parser.add_argument("--project", default="runs/shoe_compare", help="Ultralytics project directory")
|
||||||
|
parser.add_argument("--models", nargs="+", default=DEFAULT_MODELS, help="Model checkpoints or aliases")
|
||||||
|
parser.add_argument("--epochs", type=int, default=40, help="Training epochs for each model")
|
||||||
|
parser.add_argument("--imgsz", type=int, default=640, help="Training image size")
|
||||||
|
parser.add_argument("--batch", type=int, default=16, help="Training batch size")
|
||||||
|
parser.add_argument("--workers", type=int, default=8, help="Data loader workers")
|
||||||
|
parser.add_argument("--device", default="0", help="Training device, e.g. 0 or cpu")
|
||||||
|
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
||||||
|
parser.add_argument("--patience", type=int, default=20, help="Early stopping patience")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def safe_name(model_name: str) -> str:
|
||||||
|
stem = Path(model_name).stem
|
||||||
|
return stem.replace(".", "_").replace("-", "_")
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_roi(model_path: Path, roi_dir: Path, save_dir: Path, device: str) -> dict:
|
||||||
|
model = YOLO(str(model_path))
|
||||||
|
results = model.predict(
|
||||||
|
source=str(roi_dir),
|
||||||
|
conf=0.1,
|
||||||
|
save=True,
|
||||||
|
project=str(save_dir.parent),
|
||||||
|
name=save_dir.name,
|
||||||
|
exist_ok=True,
|
||||||
|
verbose=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
per_image = []
|
||||||
|
max_confs = []
|
||||||
|
hit_count = 0
|
||||||
|
for result in results:
|
||||||
|
n = 0 if result.boxes is None else len(result.boxes)
|
||||||
|
hit = n > 0
|
||||||
|
if hit:
|
||||||
|
hit_count += 1
|
||||||
|
max_confs.append(float(result.boxes.conf.max().item()))
|
||||||
|
per_image.append(
|
||||||
|
{
|
||||||
|
"image": Path(result.path).name,
|
||||||
|
"detections": n,
|
||||||
|
"max_conf": float(result.boxes.conf.max().item()) if hit else 0.0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
summary = {
|
||||||
|
"roi_total": len(per_image),
|
||||||
|
"roi_hits": hit_count,
|
||||||
|
"roi_hit_rate": hit_count / len(per_image) if per_image else 0.0,
|
||||||
|
"roi_mean_max_conf": mean(max_confs) if max_confs else 0.0,
|
||||||
|
"per_image": per_image,
|
||||||
|
}
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = parse_args()
|
||||||
|
os.environ.setdefault("YOLO_CONFIG_DIR", str(REPO_ROOT / ".ultralytics"))
|
||||||
|
|
||||||
|
data = Path(args.data)
|
||||||
|
roi_dir = Path(args.roi_dir)
|
||||||
|
project = Path(args.project)
|
||||||
|
project.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if not data.exists():
|
||||||
|
raise FileNotFoundError(f"Dataset yaml not found: {data}")
|
||||||
|
if not roi_dir.exists():
|
||||||
|
raise FileNotFoundError(f"ROI directory not found: {roi_dir}")
|
||||||
|
|
||||||
|
compare_summary = []
|
||||||
|
for model_name in args.models:
|
||||||
|
run_name = f"{safe_name(model_name)}_shoe_{args.imgsz}"
|
||||||
|
print("=" * 80)
|
||||||
|
print(f"Training {model_name} -> {run_name}")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
model = YOLO(model_name)
|
||||||
|
train_results = model.train(
|
||||||
|
data=str(data.resolve()),
|
||||||
|
epochs=args.epochs,
|
||||||
|
imgsz=args.imgsz,
|
||||||
|
batch=args.batch,
|
||||||
|
workers=args.workers,
|
||||||
|
device=args.device,
|
||||||
|
seed=args.seed,
|
||||||
|
patience=args.patience,
|
||||||
|
project=str(project.resolve()),
|
||||||
|
name=run_name,
|
||||||
|
exist_ok=True,
|
||||||
|
pretrained=True,
|
||||||
|
cos_lr=True,
|
||||||
|
amp=True,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
best_path = Path(train_results.save_dir) / "weights" / "best.pt"
|
||||||
|
roi_save_dir = Path(train_results.save_dir) / "roi_eval"
|
||||||
|
roi_summary = evaluate_roi(best_path, roi_dir, roi_save_dir, args.device)
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"model": model_name,
|
||||||
|
"run_name": run_name,
|
||||||
|
"save_dir": str(Path(train_results.save_dir).resolve()),
|
||||||
|
"best_pt": str(best_path.resolve()),
|
||||||
|
"metrics": {
|
||||||
|
"map50": float(train_results.results_dict.get("metrics/mAP50(B)", 0.0)),
|
||||||
|
"map50_95": float(train_results.results_dict.get("metrics/mAP50-95(B)", 0.0)),
|
||||||
|
"precision": float(train_results.results_dict.get("metrics/precision(B)", 0.0)),
|
||||||
|
"recall": float(train_results.results_dict.get("metrics/recall(B)", 0.0)),
|
||||||
|
},
|
||||||
|
"roi_eval": roi_summary,
|
||||||
|
}
|
||||||
|
compare_summary.append(record)
|
||||||
|
|
||||||
|
summary_path = Path(train_results.save_dir) / "roi_eval_summary.json"
|
||||||
|
summary_path.write_text(json.dumps(record, indent=2, ensure_ascii=False), encoding="utf-8")
|
||||||
|
|
||||||
|
compare_json = project / "compare_summary.json"
|
||||||
|
compare_json.write_text(json.dumps(compare_summary, indent=2, ensure_ascii=False), encoding="utf-8")
|
||||||
|
|
||||||
|
print(json.dumps(record, indent=2, ensure_ascii=False))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
19
10_run_shoe_compare.ps1
Normal file
19
10_run_shoe_compare.ps1
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
$ErrorActionPreference = "Stop"
|
||||||
|
|
||||||
|
$repo = "C:\Users\tianj\Documents\apps\DetectionModelTraining"
|
||||||
|
$env:PYTHONPATH = "$repo\.pydeps"
|
||||||
|
$env:YOLO_CONFIG_DIR = "$repo\.ultralytics"
|
||||||
|
|
||||||
|
Set-Location $repo
|
||||||
|
|
||||||
|
py -3.11 .\07_build_public_shoe_dataset.py --clean
|
||||||
|
|
||||||
|
py -3.11 .\08_train_compare_models.py `
|
||||||
|
--data "$repo\datasets\shoe-public-mix\data.yaml" `
|
||||||
|
--roi-dir "$repo\datasets\roi-shoes" `
|
||||||
|
--project "$repo\runs\shoe_compare" `
|
||||||
|
--models "yolov8s.pt" "yolo11s.pt" "yolo26s.pt" `
|
||||||
|
--epochs 40 `
|
||||||
|
--imgsz 640 `
|
||||||
|
--batch 16 `
|
||||||
|
--device 0
|
||||||
17
11_run_shoe_compare_960.ps1
Normal file
17
11_run_shoe_compare_960.ps1
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
$ErrorActionPreference = "Stop"
|
||||||
|
|
||||||
|
$repo = "C:\Users\tianj\Documents\apps\DetectionModelTraining"
|
||||||
|
$env:PYTHONPATH = "$repo\.pydeps"
|
||||||
|
$env:YOLO_CONFIG_DIR = "$repo\.ultralytics"
|
||||||
|
|
||||||
|
Set-Location $repo
|
||||||
|
|
||||||
|
py -3.11 .\08_train_compare_models.py `
|
||||||
|
--data "$repo\datasets\shoe-public-mix\data.yaml" `
|
||||||
|
--roi-dir "$repo\datasets\roi-shoes" `
|
||||||
|
--project "$repo\runs\shoe_compare_960" `
|
||||||
|
--models "yolov8s.pt" "yolo11s.pt" "yolo26s.pt" `
|
||||||
|
--epochs 40 `
|
||||||
|
--imgsz 960 `
|
||||||
|
--batch 8 `
|
||||||
|
--device 0
|
||||||
156
README.md
Normal file
156
README.md
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
# 鞋子检测模型训练指南
|
||||||
|
|
||||||
|
## 方案:640x640 单模型(部署时用2窗口)
|
||||||
|
|
||||||
|
**训练阶段**:
|
||||||
|
- 输入:640x640 完整图片
|
||||||
|
- 模型:YOLOv8s
|
||||||
|
- 输出:640x640 模型文件
|
||||||
|
|
||||||
|
**部署阶段**(pipeline配置):
|
||||||
|
- 原图 1920x1080
|
||||||
|
- 分成 2 个 960x1080 窗口
|
||||||
|
- 每个窗口 resize 到 640x640 送入模型
|
||||||
|
- 合并检测结果
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 目录结构
|
||||||
|
|
||||||
|
```
|
||||||
|
train/
|
||||||
|
├── README.md # 本文件
|
||||||
|
├── 01_download_dataset.py # 下载鞋子数据集(推荐 Open Images)
|
||||||
|
├── 02_train.bat # Windows 一键训练脚本
|
||||||
|
├── 03_export_onnx.bat # 导出 ONNX 脚本
|
||||||
|
├── 04_convert_rknn.py # 转换为 RKNN 脚本
|
||||||
|
├── 05_prepare_ppe_shoe_subset.py # 提取 PPE 鞋子单类子集
|
||||||
|
├── 06_finetune_ppe.bat # 用 PPE 鞋子子集做二阶段微调
|
||||||
|
├── data.yaml.template # 数据集配置文件
|
||||||
|
└── samples/ # 示例图片
|
||||||
|
├── calibration/
|
||||||
|
├── test_images/
|
||||||
|
└── README.md
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
### 1. 下载数据集
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd train
|
||||||
|
python 01_download_dataset.py --source openimages --max-samples 5000
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 准备配置
|
||||||
|
|
||||||
|
```bash
|
||||||
|
脚本会自动生成 datasets/openimages-shoes-yolo/data.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 训练(640x640)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
02_train.bat
|
||||||
|
```
|
||||||
|
|
||||||
|
或手动:
|
||||||
|
```bash
|
||||||
|
yolo detect train \
|
||||||
|
data=datasets/openimages-shoes-yolo/data.yaml \
|
||||||
|
model=yolov8s.pt \
|
||||||
|
epochs=150 \
|
||||||
|
imgsz=640 \
|
||||||
|
batch=16 \
|
||||||
|
device=0
|
||||||
|
```
|
||||||
|
|
||||||
|
**训练参数**:
|
||||||
|
- 模型:YOLOv8s(速度和精度平衡)
|
||||||
|
- 输入:640x640
|
||||||
|
- 预计时间:30-60分钟
|
||||||
|
|
||||||
|
### 4. 导出 ONNX
|
||||||
|
|
||||||
|
```bash
|
||||||
|
03_export_onnx.bat
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. 转换为 RKNN
|
||||||
|
|
||||||
|
在 Ubuntu PC 上:
|
||||||
|
```bash
|
||||||
|
python 04_convert_rknn.py runs/detect/train/weights/best.onnx -o shoe_detector_640.rknn -t rk3588
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. 部署(2窗口配置)
|
||||||
|
|
||||||
|
复制到 RK3588:
|
||||||
|
```bash
|
||||||
|
scp shoe_detector_640.rknn orangepi@<rk3588_ip>:/home/orangepi/apps/OrangePi3588Media/models/
|
||||||
|
```
|
||||||
|
|
||||||
|
Pipeline 配置(部署阶段用2窗口):
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "pre_shoe",
|
||||||
|
"type": "preprocess",
|
||||||
|
"windows": [
|
||||||
|
{"x": 0, "y": 0, "w": 960, "h": 1080},
|
||||||
|
{"x": 960, "y": 0, "w": 960, "h": 1080}
|
||||||
|
],
|
||||||
|
"dst_w": 640,
|
||||||
|
"dst_h": 640
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7. 方案 A:PPE 二阶段微调
|
||||||
|
|
||||||
|
当 Open Images 基础模型训练完成后,可继续用 PPE 鞋子子集做场景微调:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python 05_prepare_ppe_shoe_subset.py
|
||||||
|
06_finetune_ppe.bat
|
||||||
|
```
|
||||||
|
|
||||||
|
PPE 鞋子子集来源:
|
||||||
|
- `boots`
|
||||||
|
- `no_boots`
|
||||||
|
|
||||||
|
这两个类会统一映射成单类:
|
||||||
|
- `shoe`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 类别说明(Open Images)
|
||||||
|
|
||||||
|
Open Images 官方鞋类层级中,`Footwear` 的子类包括:
|
||||||
|
- `Boot`
|
||||||
|
- `Sandal`
|
||||||
|
- `High heels`
|
||||||
|
- `Roller skates`
|
||||||
|
|
||||||
|
本项目推荐下载:
|
||||||
|
- `Footwear`
|
||||||
|
- `Boot`
|
||||||
|
|
||||||
|
可选补充:
|
||||||
|
- `Sandal`
|
||||||
|
|
||||||
|
不建议默认加入:
|
||||||
|
- `High heels`
|
||||||
|
- `Roller skates`
|
||||||
|
|
||||||
|
训练时统一映射为单一类别:
|
||||||
|
- `0: shoe`
|
||||||
|
|
||||||
|
这样模型目标更聚焦,先尽量把鞋子稳定检出,再在后处理里判断是否为黑色鞋。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 相关链接
|
||||||
|
|
||||||
|
- [Open Images 数据集](https://storage.googleapis.com/openimages/web/index.html)
|
||||||
|
- [Ultralytics YOLOv8](https://docs.ultralytics.com/)
|
||||||
20
data.yaml.template
Normal file
20
data.yaml.template
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# 单类鞋子检测数据集配置
|
||||||
|
|
||||||
|
path: .
|
||||||
|
train: images/train
|
||||||
|
val: images/val
|
||||||
|
test: images/test
|
||||||
|
|
||||||
|
# 单类 shoe
|
||||||
|
nc: 1
|
||||||
|
names: ['shoe']
|
||||||
|
|
||||||
|
# 推荐来源:
|
||||||
|
# python 01_download_dataset.py --source openimages --max-samples 5000
|
||||||
|
#
|
||||||
|
# 推荐 Open Images 子类:
|
||||||
|
# - Footwear
|
||||||
|
# - Boot
|
||||||
|
#
|
||||||
|
# 可选补充:
|
||||||
|
# - Sandal
|
||||||
46
samples/README.md
Normal file
46
samples/README.md
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
# 示例图片目录
|
||||||
|
|
||||||
|
用于存放测试图片和量化校准样本。
|
||||||
|
|
||||||
|
## 目录结构
|
||||||
|
|
||||||
|
```
|
||||||
|
samples/
|
||||||
|
├── test_images/ # 用于测试模型的示例图片
|
||||||
|
├── calibration/ # INT8 量化校准用的图片(约 20-100 张)
|
||||||
|
└── README.md # 本文件
|
||||||
|
```
|
||||||
|
|
||||||
|
## 使用说明
|
||||||
|
|
||||||
|
### 测试图片 (test_images/)
|
||||||
|
|
||||||
|
存放一些典型场景的鞋子图片,用于验证模型效果。
|
||||||
|
|
||||||
|
### 校准图片 (calibration/)
|
||||||
|
|
||||||
|
INT8 量化时需要,用于确定量化参数。
|
||||||
|
|
||||||
|
**要求:**
|
||||||
|
- 应与实际部署场景相似
|
||||||
|
- 包含各种光照、角度、背景的样本
|
||||||
|
- 建议 20-100 张
|
||||||
|
- 图片格式: JPG, PNG, BMP
|
||||||
|
|
||||||
|
**创建校准数据集文件:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Linux/macOS
|
||||||
|
ls samples/calibration/*.jpg > dataset.txt
|
||||||
|
ls samples/calibration/*.png >> dataset.txt
|
||||||
|
|
||||||
|
# Windows CMD
|
||||||
|
dir /b samples\calibration\*.jpg > dataset.txt
|
||||||
|
dir /b samples\calibration\*.png >> dataset.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
- 校准图片越多,转换时间越长,但精度可能更高
|
||||||
|
- 建议使用训练集的部分图片作为校准集
|
||||||
|
- 不要和测试集重复,避免过拟合
|
||||||
Loading…
Reference in New Issue
Block a user