# DetectionModelTraining/01_download_dataset.py
#
# NOTE: The lines below were GitHub page chrome captured when this file was
# copied ("589 lines / 18 KiB / Python / Raw Blame History" plus a warning
# that the file contains ambiguous Unicode characters — expected here, since
# the user-facing messages are written in Chinese). Kept as comments so the
# file remains valid Python.
#!/usr/bin/env python3
"""Download shoe-detection datasets.

Supported sources:
- Ultralytics Construction-PPE
- Open Images V7 (recommended for single-class shoe detection)

Recommended Open Images classes:
- Footwear
- Boot
Optional supplement:
- Sandal
Not recommended by default:
- High heels
- Roller skates
"""
import argparse
import ast
import os
import random
import shutil
import ssl
import sys
import urllib.request
import zipfile
from pathlib import Path
# Open Images class groups used as defaults / guidance for the --classes flag.
OPENIMAGES_RECOMMENDED_CLASSES = ["Footwear", "Boot"]
OPENIMAGES_OPTIONAL_CLASSES = ["Sandal"]
OPENIMAGES_NOT_RECOMMENDED_CLASSES = ["High heels", "Roller skates"]
# Extra class kept when exporting a person+shoe ROI source dataset.
OPENIMAGES_PERSON_CLASS = "Person"
# Default output directory used for --mode roi-source when --dir is left at its default.
ROI_SOURCE_DEFAULT_DIR = "datasets/openimages-person-shoes"
def download_ultralytics_cppe(dataset_dir: str = "datasets/construction-ppe"):
    """Download and extract the Ultralytics Construction-PPE dataset.

    Downloads the release zip in chunks (with a progress display), extracts
    it into ``dataset_dir``, and removes the temporary archive.

    Args:
        dataset_dir: Directory the archive is extracted into (created if missing).

    Returns:
        True on success, False if the download or extraction failed.
    """
    url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/construction-ppe.zip"
    zip_path = "construction-ppe.zip"
    print("=" * 70)
    print("下载 Construction-PPE 数据集")
    print("=" * 70)
    print(f"来源: {url}")
    print(f"目标: {dataset_dir}")
    print()
    os.makedirs(dataset_dir, exist_ok=True)
    print("[1/3] 下载中... (约 178MB)")
    try:
        # NOTE(review): certificate verification is deliberately disabled
        # (presumably to tolerate broken cert chains in CI/corporate proxies),
        # which allows man-in-the-middle attacks on the download. Prefer a
        # verified context (e.g. certifi) where the environment allows it.
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
        with urllib.request.urlopen(url, context=ssl_context, timeout=300) as response:
            total_size = int(response.headers.get("content-length", 0))
            downloaded = 0
            chunk_size = 8192
            with open(zip_path, "wb") as f:
                while True:
                    chunk = response.read(chunk_size)
                    if not chunk:
                        break
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total_size > 0:
                        percent = (downloaded / total_size) * 100
                        print(
                            f"\r 进度: {percent:.1f}% ({downloaded}/{total_size} bytes)",
                            end="",
                        )
        print("\n ✓ 下载完成")
    except Exception as e:
        # Fix: remove any partially written archive so a retry starts clean
        # instead of extracting a truncated zip later.
        if os.path.exists(zip_path):
            try:
                os.remove(zip_path)
            except OSError:
                pass  # best-effort cleanup; the error below is the real report
        print(f"\n ✗ 下载失败: {e}")
        print("\n请手动下载:")
        print(f" 1. 访问: {url}")
        print(" 2. 下载 construction-ppe.zip")
        print(f" 3. 解压到 {dataset_dir}/")
        return False
    print("\n[2/3] 解压中...")
    try:
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(dataset_dir)
        print(" ✓ 解压完成")
    except Exception as e:
        print(f" ✗ 解压失败: {e}")
        return False
    print("\n[3/3] 清理临时文件...")
    os.remove(zip_path)
    print(" ✓ 完成")
    return True
def create_single_class_yaml(dataset_dir: str, source_name: str, dataset_path_value: str = "."):
    """Write the data.yaml for a single-class "shoe" detection dataset.

    Args:
        dataset_dir: Directory in which data.yaml is written.
        source_name: Human-readable dataset source recorded in the metadata.
        dataset_path_value: Value emitted for the YAML ``path`` key.

    Returns:
        Path of the written data.yaml file.
    """
    # NOTE(review): the dataset_info children are emitted at top level here;
    # the nesting indentation may have been lost in this copy of the file —
    # confirm the intended YAML structure.
    config_lines = [
        "# 单类鞋子检测数据集配置",
        f"path: {dataset_path_value}",
        "train: images/train",
        "val: images/val",
        "test: images/test",
        "nc: 1",
        "names: ['shoe']",
        "dataset_info:",
        f"name: {source_name}",
        "task: detect_shoe",
        "note: 所有鞋类子类统一映射为单一类别 shoe",
    ]
    yaml_path = os.path.join(dataset_dir, "data.yaml")
    with open(yaml_path, "w", encoding="utf-8") as f:
        f.write("\n".join(config_lines) + "\n")
    print(f"\n✓ 配置文件创建: {yaml_path}")
    return yaml_path
def create_roi_source_yaml(dataset_dir: str, source_name: str, dataset_path_value: str = "."):
    """Write the data.yaml for a person+shoe ROI source dataset (2 classes).

    Args:
        dataset_dir: Directory in which data.yaml is written.
        source_name: Human-readable dataset source recorded in the metadata.
        dataset_path_value: Value emitted for the YAML ``path`` key.

    Returns:
        Path of the written data.yaml file.
    """
    # NOTE(review): dataset_info children are flat (top level) in this copy;
    # nesting indentation may have been lost — confirm intended YAML layout.
    config_lines = [
        "# 人体+鞋子 ROI 源数据集配置",
        f"path: {dataset_path_value}",
        "train: images/train",
        "val: images/val",
        "test: images/test",
        "nc: 2",
        "names: ['person', 'shoe']",
        "dataset_info:",
        f"name: {source_name}",
        "task: detect_person_and_shoe_for_roi",
        "note: 用真实 Person 框生成脚部 ROIshoe 为 ROI 内检测目标",
    ]
    yaml_path = os.path.join(dataset_dir, "data.yaml")
    with open(yaml_path, "w", encoding="utf-8") as f:
        f.write("\n".join(config_lines) + "\n")
    print(f"\n✓ ROI 源配置文件创建: {yaml_path}")
    return yaml_path
def rewrite_yaml_for_existing_splits(dataset_dir: str, source_name: str):
    """Regenerate the single-class data.yaml from the splits present on disk.

    Missing canonical splits fall back: train -> first existing split,
    val -> train, test -> val.

    Args:
        dataset_dir: Dataset root containing an ``images/`` directory.
        source_name: Dataset source recorded in the metadata.

    Returns:
        Path of the rewritten data.yaml, as a string.

    Raises:
        RuntimeError: If none of images/train, images/val, images/test exist.
    """
    images_root = Path(dataset_dir) / "images"
    available = [split for split in ("train", "val", "test") if (images_root / split).exists()]
    if not available:
        raise RuntimeError(f"未找到任何图像 split: {images_root}")
    train_split = "train" if "train" in available else available[0]
    val_split = "val" if "val" in available else train_split
    test_split = "test" if "test" in available else val_split
    body = (
        "# 单类鞋子检测数据集配置\n"
        "path: .\n"
        f"train: images/{train_split}\n"
        f"val: images/{val_split}\n"
        f"test: images/{test_split}\n"
        "nc: 1\n"
        "names: ['shoe']\n"
        "dataset_info:\n"
        f"name: {source_name}\n"
        "task: detect_shoe\n"
        "note: 所有鞋类子类统一映射为单一类别 shoe\n"
    )
    yaml_path = Path(dataset_dir) / "data.yaml"
    yaml_path.write_text(body, encoding="utf-8")
    print(f"\n✓ 配置文件更新: {yaml_path}")
    return str(yaml_path)
def rewrite_roi_yaml_for_existing_splits(dataset_dir: str, source_name: str):
    """Regenerate the person+shoe ROI data.yaml from the splits on disk.

    Missing canonical splits fall back: train -> first existing split,
    val -> train, test -> val.

    Args:
        dataset_dir: Dataset root containing an ``images/`` directory.
        source_name: Dataset source recorded in the metadata.

    Returns:
        Path of the rewritten data.yaml, as a string.

    Raises:
        RuntimeError: If none of images/train, images/val, images/test exist.
    """
    images_root = Path(dataset_dir) / "images"
    available = [split for split in ("train", "val", "test") if (images_root / split).exists()]
    if not available:
        raise RuntimeError(f"未找到任何图像 split: {images_root}")
    train_split = "train" if "train" in available else available[0]
    val_split = "val" if "val" in available else train_split
    test_split = "test" if "test" in available else val_split
    body = (
        "# 人体+鞋子 ROI 源数据集配置\n"
        "path: .\n"
        f"train: images/{train_split}\n"
        f"val: images/{val_split}\n"
        f"test: images/{test_split}\n"
        "nc: 2\n"
        "names: ['person', 'shoe']\n"
        "dataset_info:\n"
        f"name: {source_name}\n"
        "task: detect_person_and_shoe_for_roi\n"
        "note: 用真实 Person 框生成脚部 ROIshoe 为 ROI 内检测目标\n"
    )
    yaml_path = Path(dataset_dir) / "data.yaml"
    yaml_path.write_text(body, encoding="utf-8")
    print(f"\n✓ ROI 源配置文件更新: {yaml_path}")
    return str(yaml_path)
def load_dataset_name_map(dataset_yaml: Path) -> dict[int, str]:
    """Parse the ``names`` entry of a YOLO dataset.yaml without a YAML library.

    Supports the inline list form (``names: ['a', 'b']``) and the indented
    mapping form (``names:`` followed by ``0: a`` child lines). Only the
    first ``names:`` entry is considered.

    Args:
        dataset_yaml: Path to the dataset.yaml file.

    Returns:
        Mapping of class index -> class name (empty if no usable entry found).
    """
    result: dict[int, str] = {}
    all_lines = dataset_yaml.read_text(encoding="utf-8").splitlines()
    for pos, raw in enumerate(all_lines):
        text = raw.strip()
        if not text.startswith("names:"):
            continue
        inline = text[len("names:"):].strip()
        if inline:
            parsed = ast.literal_eval(inline)
            if isinstance(parsed, list):
                return {i: str(entry) for i, entry in enumerate(parsed)}
        # Mapping form: consume the space-indented block directly below "names:".
        for follower in all_lines[pos + 1:]:
            if not follower.startswith(" "):
                break
            entry = follower.strip()
            if ":" not in entry:
                continue
            key_part, _, value_part = entry.partition(":")
            if key_part.strip().isdigit():
                result[int(key_part.strip())] = value_part.strip().strip("'\"")
        break
    return result
def ensure_openimages_train_val_split(export_dir: str, train_ratio: float = 0.9, seed: int = 42):
    """Split a val-only YOLO export into train/val by moving files.

    No-op when a train split (images and labels) already exists, when there
    is no val split to draw from, or when fewer than two images are present.
    The shuffle is seeded, so the split is deterministic for a given seed.

    Args:
        export_dir: Root of the YOLO export (contains images/ and labels/).
        train_ratio: Fraction of images moved into the train split.
        seed: Seed for the shuffle performed before splitting.
    """
    root = Path(export_dir)
    img_train = root / "images" / "train"
    img_val = root / "images" / "val"
    lbl_train = root / "labels" / "train"
    lbl_val = root / "labels" / "val"
    if img_train.exists() and lbl_train.exists():
        return  # already split
    if not (img_val.exists() and lbl_val.exists()):
        return  # nothing to split from
    candidates = sorted(p for p in img_val.iterdir() if p.is_file())
    if len(candidates) < 2:
        return
    img_train.mkdir(parents=True, exist_ok=True)
    lbl_train.mkdir(parents=True, exist_ok=True)
    random.Random(seed).shuffle(candidates)
    # Clamp so both splits keep at least one image.
    cut = max(1, min(len(candidates) - 1, int(len(candidates) * train_ratio)))
    for image_path in candidates[:cut]:
        matching_label = lbl_val / f"{image_path.stem}.txt"
        shutil.move(str(image_path), img_train / image_path.name)
        if matching_label.exists():
            shutil.move(str(matching_label), lbl_train / matching_label.name)
    print("\n自动切分单一 split 为 train/val:")
    print(f" train: {len(list(img_train.iterdir()))}")
    print(f" val: {len(list(img_val.iterdir()))}")
def merge_openimages_to_single_class(export_dir: str):
    """Collapse every exported Open Images class to class 0 ("shoe").

    Rewrites each YOLO label file under labels/ in place (dropping malformed
    rows), then ensures a train/val split exists and regenerates data.yaml.

    Args:
        export_dir: Root of the YOLO export.

    Returns:
        True on success, False when the labels directory is missing.
    """
    labels_root = Path(export_dir) / "labels"
    if not labels_root.exists():
        print(f"✗ 未找到标签目录: {labels_root}")
        return False
    total_files = 0
    total_boxes = 0
    changed_files = 0
    for label_file in labels_root.rglob("*.txt"):
        total_files += 1
        kept_rows = []
        for row in label_file.read_text(encoding="utf-8").splitlines():
            fields = row.strip().split()
            if len(fields) < 5:
                continue  # skip malformed / empty rows
            fields[0] = "0"  # every shoe subclass becomes class 0
            kept_rows.append(" ".join(fields))
        total_boxes += len(kept_rows)
        if kept_rows:
            changed_files += 1
            label_file.write_text("\n".join(kept_rows) + "\n", encoding="utf-8")
        else:
            label_file.write_text("", encoding="utf-8")
    ensure_openimages_train_val_split(export_dir)
    rewrite_yaml_for_existing_splits(export_dir, "Open Images V7")
    print("\n单类合并完成:")
    print(f" 标签文件: {changed_files}/{total_files}")
    print(f" 标注框数: {total_boxes}")
    return True
def rewrite_openimages_to_roi_source(export_dir: str):
    """Rewrite exported Open Images labels into person(0) + shoe(1) classes.

    A label file keeps its (remapped) boxes only when it contains at least
    one person box AND one shoe box; otherwise it is emptied. Afterwards a
    train/val split is ensured and the ROI-source data.yaml is regenerated.
    Box totals include boxes from files that were later emptied.

    Args:
        export_dir: Root of the YOLO export (must contain dataset.yaml).

    Returns:
        True on success, False when labels/, dataset.yaml, the Person class,
        or any shoe class is missing.
    """
    labels_root = Path(export_dir) / "labels"
    dataset_yaml = Path(export_dir) / "dataset.yaml"
    if not labels_root.exists():
        print(f"✗ 未找到标签目录: {labels_root}")
        return False
    if not dataset_yaml.exists():
        print(f"✗ 未找到 dataset.yaml: {dataset_yaml}")
        return False
    name_map = load_dataset_name_map(dataset_yaml)
    shoe_names = {"footwear", "boot", "sandal", "high heels", "roller skates"}
    person_ids = {cid for cid, label in name_map.items() if label.lower() == "person"}
    shoe_ids = {cid for cid, label in name_map.items() if label.lower() in shoe_names}
    if not person_ids:
        print("✗ 导出结果中未找到 Person 类")
        return False
    if not shoe_ids:
        print("✗ 导出结果中未找到鞋类")
        return False
    kept_files = 0
    total_person = 0
    total_shoe = 0
    for label_file in labels_root.rglob("*.txt"):
        person_rows = 0
        shoe_rows = 0
        remapped: list[str] = []
        for row in label_file.read_text(encoding="utf-8").splitlines():
            fields = row.strip().split()
            if len(fields) < 5 or not fields[0].isdigit():
                continue  # skip malformed rows
            cid = int(fields[0])
            if cid in person_ids:
                fields[0] = "0"  # person -> class 0
                remapped.append(" ".join(fields))
                total_person += 1
                person_rows += 1
            elif cid in shoe_ids:
                fields[0] = "1"  # any shoe subclass -> class 1
                remapped.append(" ".join(fields))
                total_shoe += 1
                shoe_rows += 1
        if person_rows > 0 and shoe_rows > 0:
            # Keep only images that can anchor a foot ROI (person AND shoe).
            label_file.write_text("\n".join(remapped) + "\n", encoding="utf-8")
            kept_files += 1
        else:
            label_file.write_text("", encoding="utf-8")
    ensure_openimages_train_val_split(export_dir)
    rewrite_roi_yaml_for_existing_splits(export_dir, "Open Images V7 ROI Source")
    print("\nROI 源重写完成:")
    print(f" 同时含 person + shoe 的标签文件: {kept_files}")
    print(f" Person 标注框数: {total_person}")
    print(f" Shoe 标注框数: {total_shoe}")
    return True
def download_openimages(classes: list, max_samples: int, dataset_dir: str, mode: str):
    """Download Open Images V7 via FiftyOne and export it as a YOLO dataset.

    The export is then post-processed: "roi-source" mode keeps person+shoe
    (two classes), anything else collapses all classes to a single "shoe"
    class. Any pre-existing export directory is deleted first.

    Args:
        classes: Open Images class names to request.
        max_samples: Maximum number of samples to download.
        dataset_dir: Base directory; the YOLO export goes to "<dataset_dir>-yolo".
        mode: "roi-source" to keep Person boxes, otherwise single-class.

    Returns:
        True on success, False when fiftyone is missing or the download,
        export, or post-processing fails.
    """
    try:
        import fiftyone as fo
        import fiftyone.zoo as foz
    except ImportError:
        print("错误: 未安装 fiftyone")
        print("请运行: pip install fiftyone fiftyone-db-ubuntu2204")
        return False
    export_dir = dataset_dir + "-yolo"
    wanted = list(classes)
    if mode == "roi-source" and OPENIMAGES_PERSON_CLASS not in wanted:
        # ROI construction needs real Person boxes alongside the shoe classes.
        wanted = [OPENIMAGES_PERSON_CLASS] + wanted
    print("=" * 70)
    print("下载 Open Images V7 数据集")
    print("=" * 70)
    print(f"模式: {mode}")
    print(f"类别: {wanted}")
    print(f"最大样本数: {max_samples}")
    print("原始缓存目录: FiftyOne 默认缓存目录")
    print(f"YOLO 导出目录: {export_dir}")
    print()
    if os.path.exists(export_dir):
        print(f"检测到已有导出目录,先删除: {export_dir}")
        shutil.rmtree(export_dir)
    try:
        zoo_dataset = foz.load_zoo_dataset(
            "open-images-v7",
            split="train",
            label_types=["detections"],
            classes=wanted,
            max_samples=max_samples,
        )
        print("\n导出为 YOLO 格式...")
        zoo_dataset.export(
            export_dir=export_dir,
            dataset_type=fo.types.YOLOv5Dataset,
            label_field="ground_truth",
        )
        postprocess = (
            rewrite_openimages_to_roi_source
            if mode == "roi-source"
            else merge_openimages_to_single_class
        )
        if not postprocess(export_dir):
            return False
        print(f"\n✓ 数据集保存: {export_dir}")
        return True
    except Exception as e:
        print(f"✗ 下载失败: {e}")
        return False
def check_dataset(dataset_dir: str):
    """Verify the expected YOLO directory layout and data.yaml exist.

    Prints a per-directory file count (or a missing marker) for the six
    standard split directories, then checks data.yaml.

    Args:
        dataset_dir: Dataset root to inspect.

    Returns:
        True when every split directory and data.yaml are present.
    """
    print("\n" + "=" * 70)
    print("检查数据集")
    print("=" * 70)
    all_ok = True
    for kind in ("images", "labels"):
        for split in ("train", "val", "test"):
            dir_name = f"{kind}/{split}"
            full_path = os.path.join(dataset_dir, dir_name)
            if os.path.exists(full_path):
                count = sum(
                    1
                    for entry in os.listdir(full_path)
                    if os.path.isfile(os.path.join(full_path, entry))
                )
                print(f"{dir_name}: {count} 个文件")
            else:
                print(f"{dir_name}: 不存在")
                all_ok = False
    yaml_path = os.path.join(dataset_dir, "data.yaml")
    if os.path.exists(yaml_path):
        print(" ✓ data.yaml: 已生成")
    else:
        print(" ✗ data.yaml: 不存在")
        all_ok = False
    return all_ok
def main():
    """CLI entry point: parse arguments, download, post-process, and verify.

    Returns:
        0 on success, 1 on failure (used as the process exit code).
    """
    parser = argparse.ArgumentParser(
        description="下载鞋子检测数据集",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
# 下载 Construction-PPE
python 01_download_dataset.py --source ultralytics
# 下载 Open Images 推荐鞋子类别
python 01_download_dataset.py --source openimages --max-samples 5000
# 如需补充凉鞋样本
python 01_download_dataset.py --source openimages --classes Footwear Boot Sandal --max-samples 8000
""",
    )
    parser.add_argument(
        "--source",
        choices=["ultralytics", "openimages"],
        default="openimages",
        help="数据源 (默认: openimages)",
    )
    parser.add_argument(
        "--dir",
        default="datasets/openimages-shoes",
        help="数据集保存目录",
    )
    parser.add_argument(
        "--mode",
        choices=["single-class", "roi-source"],
        default="single-class",
        help="导出模式: single-class 用于单类训练, roi-source 保留 person + shoe 供 ROI 构建",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=5000,
        help="Open Images 最大样本数 (默认: 5000)",
    )
    parser.add_argument(
        "--classes",
        nargs="+",
        default=OPENIMAGES_RECOMMENDED_CLASSES,
        help="Open Images 类别 (默认: Footwear Boot)",
    )
    args = parser.parse_args()
    success = False
    final_dataset_dir = args.dir
    print("=" * 70)
    print("鞋子检测数据集准备")
    print("=" * 70)
    if args.source == "openimages":
        print(f"推荐类别: {OPENIMAGES_RECOMMENDED_CLASSES}")
        print(f"可选补充: {OPENIMAGES_OPTIONAL_CLASSES}")
        print(f"默认不建议: {OPENIMAGES_NOT_RECOMMENDED_CLASSES}")
        if args.mode == "roi-source":
            print(f"将额外保留: {OPENIMAGES_PERSON_CLASS}")
    print()
    if args.source == "ultralytics":
        success = download_ultralytics_cppe(args.dir)
        if success:
            # NOTE(review): path value "construction-ppe" is relative to the
            # yaml file's location; this assumes the zip extracts into a
            # nested construction-ppe/ folder — confirm the archive layout.
            create_single_class_yaml(args.dir, "Construction-PPE", dataset_path_value="construction-ppe")
            check_dataset(args.dir)
    elif args.source == "openimages":
        # Redirect the default output dir for roi-source mode, but only when
        # the user did not pass an explicit --dir.
        if args.mode == "roi-source" and args.dir == parser.get_default("dir"):
            args.dir = ROI_SOURCE_DEFAULT_DIR
        success = download_openimages(args.classes, args.max_samples, args.dir, args.mode)
        # download_openimages exports to "<dir>-yolo"; report that path.
        final_dataset_dir = args.dir + "-yolo"
        if success:
            check_dataset(final_dataset_dir)
    if success:
        print("\n" + "=" * 70)
        print("数据集准备完成!")
        print("=" * 70)
        print(f"训练数据集路径: {final_dataset_dir}")
        print("\n下一步:")
        print(f" 1. 检查配置: {final_dataset_dir}/data.yaml")
        print(f" 2. 开始训练: 02_train.bat")
        print(
            f" 3. 或手动: yolo detect train data={final_dataset_dir}/data.yaml model=yolov8s.pt epochs=150 imgsz=640"
        )
        return 0
    print("\n✗ 数据集准备失败")
    return 1
# Script entry point: propagate main()'s status code to the shell.
if __name__ == "__main__":
    sys.exit(main())