#!/usr/bin/env python3
"""
下载鞋子检测数据集。

支持:
- Ultralytics Construction-PPE
- Open Images V7 (推荐用于单类 shoe 检测)

Open Images 推荐类别:
- Footwear
- Boot

可选补充:
- Sandal

不建议默认加入:
- High heels
- Roller skates
"""

import argparse
import ast
import os
import random
import shutil
import ssl
import sys
import urllib.request
import zipfile
from pathlib import Path


# Open Images class names to request for shoe detection (see module docstring
# for why the "not recommended" ones are excluded by default).
OPENIMAGES_RECOMMENDED_CLASSES = ["Footwear", "Boot"]
OPENIMAGES_OPTIONAL_CLASSES = ["Sandal"]
OPENIMAGES_NOT_RECOMMENDED_CLASSES = ["High heels", "Roller skates"]
# Extra class kept when exporting a person+shoe ROI source dataset.
OPENIMAGES_PERSON_CLASS = "Person"
# Default output directory used when --mode roi-source is combined with the
# default --dir value.
ROI_SOURCE_DEFAULT_DIR = "datasets/openimages-person-shoes"
def download_ultralytics_cppe(dataset_dir: str = "datasets/construction-ppe"):
    """Download and extract the Ultralytics Construction-PPE dataset.

    Args:
        dataset_dir: Directory the archive is extracted into (created if
            missing).

    Returns:
        True on success, False if the download or extraction failed (manual
        download instructions are printed in that case).
    """
    url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/construction-ppe.zip"
    zip_path = "construction-ppe.zip"

    print("=" * 70)
    print("下载 Construction-PPE 数据集")
    print("=" * 70)
    print(f"来源: {url}")
    print(f"目标: {dataset_dir}")
    print()

    os.makedirs(dataset_dir, exist_ok=True)

    print("[1/3] 下载中... (约 178MB)")
    try:
        # NOTE(security): certificate verification is deliberately disabled so
        # the download also works behind TLS-intercepting proxies; this trades
        # away MITM protection for convenience.
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE

        with urllib.request.urlopen(url, context=ssl_context, timeout=300) as response:
            total_size = int(response.headers.get("content-length", 0))
            downloaded = 0
            chunk_size = 8192

            with open(zip_path, "wb") as f:
                while True:
                    chunk = response.read(chunk_size)
                    if not chunk:
                        break
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total_size > 0:
                        percent = (downloaded / total_size) * 100
                        # flush=True: without it the carriage-return progress
                        # line (no newline) stays in the stdout buffer and the
                        # user sees no progress until the download finishes.
                        print(
                            f"\r 进度: {percent:.1f}% ({downloaded}/{total_size} bytes)",
                            end="",
                            flush=True,
                        )

        print("\n ✓ 下载完成")
    except Exception as e:
        print(f"\n ✗ 下载失败: {e}")
        print("\n请手动下载:")
        print(f" 1. 访问: {url}")
        print(" 2. 下载 construction-ppe.zip")
        print(f" 3. 解压到 {dataset_dir}/")
        # Remove any partially written archive so a retry starts clean and a
        # later manual extraction never sees a truncated zip.
        if os.path.exists(zip_path):
            os.remove(zip_path)
        return False

    print("\n[2/3] 解压中...")
    try:
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(dataset_dir)
        print(" ✓ 解压完成")
    except Exception as e:
        print(f" ✗ 解压失败: {e}")
        return False

    print("\n[3/3] 清理临时文件...")
    os.remove(zip_path)
    print(" ✓ 完成")
    return True
def create_single_class_yaml(dataset_dir: str, source_name: str, dataset_path_value: str = "."):
    """Write a data.yaml describing a single-class ('shoe') detection dataset.

    Args:
        dataset_dir: Directory the data.yaml is written into.
        source_name: Human-readable dataset name recorded under dataset_info.
        dataset_path_value: Value for the yaml ``path`` key.

    Returns:
        The path of the written data.yaml as a string.
    """
    yaml_path = os.path.join(dataset_dir, "data.yaml")
    Path(yaml_path).write_text(
        f"""# 单类鞋子检测数据集配置

path: {dataset_path_value}
train: images/train
val: images/val
test: images/test

nc: 1
names: ['shoe']

dataset_info:
  name: {source_name}
  task: detect_shoe
  note: 所有鞋类子类统一映射为单一类别 shoe
""",
        encoding="utf-8",
    )
    print(f"\n✓ 配置文件创建: {yaml_path}")
    return yaml_path
def create_roi_source_yaml(dataset_dir: str, source_name: str, dataset_path_value: str = "."):
    """Write a data.yaml for a two-class person+shoe ROI source dataset.

    Args:
        dataset_dir: Directory the data.yaml is written into.
        source_name: Human-readable dataset name recorded under dataset_info.
        dataset_path_value: Value for the yaml ``path`` key.

    Returns:
        The path of the written data.yaml as a string.
    """
    yaml_path = os.path.join(dataset_dir, "data.yaml")
    Path(yaml_path).write_text(
        f"""# 人体+鞋子 ROI 源数据集配置

path: {dataset_path_value}
train: images/train
val: images/val
test: images/test

nc: 2
names: ['person', 'shoe']

dataset_info:
  name: {source_name}
  task: detect_person_and_shoe_for_roi
  note: 用真实 Person 框生成脚部 ROI,shoe 为 ROI 内检测目标
""",
        encoding="utf-8",
    )
    print(f"\n✓ ROI 源配置文件创建: {yaml_path}")
    return yaml_path
def rewrite_yaml_for_existing_splits(dataset_dir: str, source_name: str):
    """Regenerate the single-class data.yaml to match the splits on disk.

    Missing splits fall back gracefully: train defaults to the first split
    found, val to train, and test to val.

    Raises:
        RuntimeError: if no image split directory exists at all.

    Returns:
        The path of the written data.yaml as a string.
    """
    root = Path(dataset_dir)
    images_root = root / "images"
    available = [s for s in ("train", "val", "test") if (images_root / s).exists()]
    if not available:
        raise RuntimeError(f"未找到任何图像 split: {images_root}")

    train_split = "train" if "train" in available else available[0]
    val_split = "val" if "val" in available else train_split
    test_split = "test" if "test" in available else val_split

    yaml_path = root / "data.yaml"
    yaml_path.write_text(
        f"""# 单类鞋子检测数据集配置

path: .
train: images/{train_split}
val: images/{val_split}
test: images/{test_split}

nc: 1
names: ['shoe']

dataset_info:
  name: {source_name}
  task: detect_shoe
  note: 所有鞋类子类统一映射为单一类别 shoe
""",
        encoding="utf-8",
    )
    print(f"\n✓ 配置文件更新: {yaml_path}")
    return str(yaml_path)
def rewrite_roi_yaml_for_existing_splits(dataset_dir: str, source_name: str):
    """Regenerate the person+shoe ROI data.yaml to match the splits on disk.

    Missing splits fall back gracefully: train defaults to the first split
    found, val to train, and test to val.

    Raises:
        RuntimeError: if no image split directory exists at all.

    Returns:
        The path of the written data.yaml as a string.
    """
    root = Path(dataset_dir)
    images_root = root / "images"
    available = [s for s in ("train", "val", "test") if (images_root / s).exists()]
    if not available:
        raise RuntimeError(f"未找到任何图像 split: {images_root}")

    train_split = "train" if "train" in available else available[0]
    val_split = "val" if "val" in available else train_split
    test_split = "test" if "test" in available else val_split

    yaml_path = root / "data.yaml"
    yaml_path.write_text(
        f"""# 人体+鞋子 ROI 源数据集配置

path: .
train: images/{train_split}
val: images/{val_split}
test: images/{test_split}

nc: 2
names: ['person', 'shoe']

dataset_info:
  name: {source_name}
  task: detect_person_and_shoe_for_roi
  note: 用真实 Person 框生成脚部 ROI,shoe 为 ROI 内检测目标
""",
        encoding="utf-8",
    )
    print(f"\n✓ ROI 源配置文件更新: {yaml_path}")
    return str(yaml_path)
def load_dataset_name_map(dataset_yaml: Path) -> dict[int, str]:
    """Parse the ``names`` entry of a YOLO dataset yaml without a yaml library.

    Supports both the inline form (``names: ['a', 'b']`` — with or without
    quotes, since YAML flow lists allow bare words that ``ast.literal_eval``
    rejects) and the indented mapping form (``0: name`` child lines).

    Args:
        dataset_yaml: Path to the yaml file to read.

    Returns:
        Mapping of class index to class name; empty when no ``names`` key
        is found.
    """
    names: dict[int, str] = {}
    lines = dataset_yaml.read_text(encoding="utf-8").splitlines()
    for index, line in enumerate(lines):
        stripped = line.strip()
        if not stripped.startswith("names:"):
            continue

        inline_value = stripped[len("names:") :].strip()
        if inline_value:
            try:
                value = ast.literal_eval(inline_value)
            except (ValueError, SyntaxError):
                # YAML flow lists may contain unquoted items (e.g.
                # ``[person, shoe]``) which are not Python literals —
                # fall back to a manual comma split.
                value = [
                    item.strip().strip("'\"")
                    for item in inline_value.strip("[]").split(",")
                    if item.strip()
                ]
            if isinstance(value, list):
                return {idx: str(name) for idx, name in enumerate(value)}

        # Mapping form: consume the indented block of "index: name" children.
        for child in lines[index + 1 :]:
            if not child.startswith(" "):
                break  # dedent marks the end of the names block
            child_stripped = child.strip()
            if ":" not in child_stripped:
                continue
            key_text, value_text = child_stripped.split(":", 1)
            if key_text.strip().isdigit():
                names[int(key_text.strip())] = value_text.strip().strip("'\"")
        break
    return names
def ensure_openimages_train_val_split(export_dir: str, train_ratio: float = 0.9, seed: int = 42):
    """Split a val-only export into train/val splits in place.

    No-op when a train split already exists, when no val split is present,
    or when there are fewer than two images to divide. The split is
    deterministic for a given seed.

    Args:
        export_dir: Root of the YOLO export (contains images/ and labels/).
        train_ratio: Fraction of images moved to train (both splits keep
            at least one image).
        seed: Seed for the reproducible shuffle.
    """
    root = Path(export_dir)
    train_images = root / "images" / "train"
    val_images = root / "images" / "val"
    train_labels = root / "labels" / "train"
    val_labels = root / "labels" / "val"

    # Already split — nothing to do.
    if train_images.exists() and train_labels.exists():
        return
    # No source split to divide.
    if not (val_images.exists() and val_labels.exists()):
        return

    candidates = sorted(p for p in val_images.iterdir() if p.is_file())
    if len(candidates) < 2:
        return

    train_images.mkdir(parents=True, exist_ok=True)
    train_labels.mkdir(parents=True, exist_ok=True)

    # Sorted input + fixed seed keeps the split reproducible across runs.
    rng = random.Random(seed)
    rng.shuffle(candidates)

    cut = max(1, min(len(candidates) - 1, int(len(candidates) * train_ratio)))

    for image_path in candidates[:cut]:
        sidecar = val_labels / f"{image_path.stem}.txt"
        shutil.move(str(image_path), train_images / image_path.name)
        # Move the label alongside its image when one exists (images without
        # annotations have no sidecar file).
        if sidecar.exists():
            shutil.move(str(sidecar), train_labels / sidecar.name)

    print("\n自动切分单一 split 为 train/val:")
    print(f" train: {len(list(train_images.iterdir()))} 张")
    print(f" val: {len(list(val_images.iterdir()))} 张")
def merge_openimages_to_single_class(export_dir: str):
    """Collapse every class in the exported YOLO labels into class 0 ('shoe').

    Rewrites each labels/**/*.txt in place (dropping malformed rows), then
    auto-splits a val-only export into train/val and regenerates data.yaml.

    Args:
        export_dir: Root of the YOLO export.

    Returns:
        True on success, False when the labels directory is missing.
    """
    labels_root = Path(export_dir) / "labels"
    if not labels_root.exists():
        print(f"✗ 未找到标签目录: {labels_root}")
        return False

    total_files = 0
    total_boxes = 0
    changed_files = 0

    for label_file in labels_root.rglob("*.txt"):
        total_files += 1
        kept = []
        for raw in label_file.read_text(encoding="utf-8").splitlines():
            fields = raw.strip().split()
            # YOLO rows are "class cx cy w h"; skip anything shorter.
            if len(fields) < 5:
                continue
            kept.append(" ".join(["0"] + fields[1:]))

        total_boxes += len(kept)
        if kept:
            changed_files += 1
        label_file.write_text("\n".join(kept) + ("\n" if kept else ""), encoding="utf-8")

    ensure_openimages_train_val_split(export_dir)
    rewrite_yaml_for_existing_splits(export_dir, "Open Images V7")

    print("\n单类合并完成:")
    print(f" 标签文件: {changed_files}/{total_files}")
    print(f" 标注框数: {total_boxes}")
    return True
def rewrite_openimages_to_roi_source(export_dir: str):
    """Rewrite an Open Images YOLO export into person+shoe ROI source labels.

    Person boxes are remapped to class 0 and all shoe-like classes to class 1.
    Label files that do not contain at least one person AND one shoe box are
    emptied, since such images cannot supply a person ROI with shoes inside.

    Args:
        export_dir: Root of the YOLO export (must contain labels/ and the
            FiftyOne-written dataset.yaml).

    Returns:
        True on success, False when required inputs or classes are missing.
    """
    labels_root = Path(export_dir) / "labels"
    dataset_yaml = Path(export_dir) / "dataset.yaml"
    if not labels_root.exists():
        print(f"✗ 未找到标签目录: {labels_root}")
        return False
    if not dataset_yaml.exists():
        print(f"✗ 未找到 dataset.yaml: {dataset_yaml}")
        return False

    # Resolve exported class ids by name; matching is case-insensitive.
    name_map = load_dataset_name_map(dataset_yaml)
    person_ids = {idx for idx, name in name_map.items() if name.lower() == "person"}
    shoe_ids = {
        idx
        for idx, name in name_map.items()
        if name.lower() in {"footwear", "boot", "sandal", "high heels", "roller skates"}
    }

    if not person_ids:
        print("✗ 导出结果中未找到 Person 类")
        return False
    if not shoe_ids:
        print("✗ 导出结果中未找到鞋类")
        return False

    kept_files = 0
    total_person = 0
    total_shoe = 0

    for label_file in labels_root.rglob("*.txt"):
        lines = label_file.read_text(encoding="utf-8").splitlines()
        rewritten: list[str] = []
        file_person = 0
        file_shoe = 0

        for line in lines:
            parts = line.strip().split()
            # Skip malformed rows (YOLO rows are "class cx cy w h").
            if len(parts) < 5 or not parts[0].isdigit():
                continue
            class_id = int(parts[0])
            if class_id in person_ids:
                parts[0] = "0"  # person -> class 0
                rewritten.append(" ".join(parts))
                total_person += 1
                file_person += 1
            elif class_id in shoe_ids:
                parts[0] = "1"  # any shoe subtype -> class 1
                rewritten.append(" ".join(parts))
                total_shoe += 1
                file_shoe += 1
            # Boxes of any other class are silently dropped.

        if file_person > 0 and file_shoe > 0:
            label_file.write_text("\n".join(rewritten) + "\n", encoding="utf-8")
            kept_files += 1
        else:
            # Without both classes the image is useless for ROI training;
            # empty the file so YOLO treats the image as background.
            label_file.write_text("", encoding="utf-8")

    ensure_openimages_train_val_split(export_dir)
    rewrite_roi_yaml_for_existing_splits(export_dir, "Open Images V7 ROI Source")

    print("\nROI 源重写完成:")
    print(f" 同时含 person + shoe 的标签文件: {kept_files}")
    print(f" Person 标注框数: {total_person}")
    print(f" Shoe 标注框数: {total_shoe}")
    return True
def download_openimages(classes: list, max_samples: int, dataset_dir: str, mode: str):
    """Download Open Images V7 via FiftyOne and export it as a YOLO dataset.

    Args:
        classes: Open Images class names to request.
        max_samples: Upper bound on the number of downloaded samples.
        dataset_dir: Base directory name; the YOLO export is written to
            ``dataset_dir + "-yolo"``.
        mode: "roi-source" keeps Person plus shoe classes for ROI training;
            any other value merges everything into a single 'shoe' class.

    Returns:
        True on success, False on any failure (missing fiftyone, download
        error, or failed post-processing).
    """
    try:
        # Imported lazily so the rest of the script works without fiftyone.
        import fiftyone as fo
        import fiftyone.zoo as foz
    except ImportError:
        print("错误: 未安装 fiftyone")
        print("请运行: pip install fiftyone fiftyone-db-ubuntu2204")
        return False

    export_dir = dataset_dir + "-yolo"
    requested_classes = list(classes)
    # ROI mode needs real Person boxes, so ensure the class is requested.
    if mode == "roi-source" and OPENIMAGES_PERSON_CLASS not in requested_classes:
        requested_classes = [OPENIMAGES_PERSON_CLASS] + requested_classes

    print("=" * 70)
    print("下载 Open Images V7 数据集")
    print("=" * 70)
    print(f"模式: {mode}")
    print(f"类别: {requested_classes}")
    print(f"最大样本数: {max_samples}")
    print("原始缓存目录: FiftyOne 默认缓存目录")
    print(f"YOLO 导出目录: {export_dir}")
    print()

    if os.path.exists(export_dir):
        # Re-export from scratch so stale labels never mix with the new run.
        print(f"检测到已有导出目录,先删除: {export_dir}")
        shutil.rmtree(export_dir)

    try:
        dataset = foz.load_zoo_dataset(
            "open-images-v7",
            split="train",
            label_types=["detections"],
            classes=requested_classes,
            max_samples=max_samples,
        )

        print("\n导出为 YOLO 格式...")
        dataset.export(
            export_dir=export_dir,
            dataset_type=fo.types.YOLOv5Dataset,
            label_field="ground_truth",
        )

        # Post-process the raw export into the requested training layout.
        if mode == "roi-source":
            if not rewrite_openimages_to_roi_source(export_dir):
                return False
        else:
            if not merge_openimages_to_single_class(export_dir):
                return False

        print(f"\n✓ 数据集保存: {export_dir}")
        return True
    except Exception as e:
        # Broad catch is intentional: this is a CLI script and any failure
        # here should print a message and report False, not crash.
        print(f"✗ 下载失败: {e}")
        return False
def check_dataset(dataset_dir: str):
    """Check that the expected YOLO directory layout and data.yaml exist.

    Prints a per-directory file count for each required split directory and
    returns True only when every required path is present.
    """
    print("\n" + "=" * 70)
    print("检查数据集")
    print("=" * 70)

    all_ok = True
    for kind in ("images", "labels"):
        for split in ("train", "val", "test"):
            rel = f"{kind}/{split}"
            full_path = os.path.join(dataset_dir, rel)
            if os.path.exists(full_path):
                file_count = sum(
                    1
                    for entry in os.listdir(full_path)
                    if os.path.isfile(os.path.join(full_path, entry))
                )
                print(f" ✓ {rel}: {file_count} 个文件")
            else:
                print(f" ✗ {rel}: 不存在")
                all_ok = False

    if os.path.exists(os.path.join(dataset_dir, "data.yaml")):
        print(" ✓ data.yaml: 已生成")
    else:
        print(" ✗ data.yaml: 不存在")
        all_ok = False

    return all_ok
def main():
    """CLI entry point: parse arguments, download the chosen dataset, and
    report the resulting dataset directory.

    Returns:
        0 on success, 1 on failure (suitable for sys.exit).
    """
    parser = argparse.ArgumentParser(
        description="下载鞋子检测数据集",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
# 下载 Construction-PPE
python 01_download_dataset.py --source ultralytics

# 下载 Open Images 推荐鞋子类别
python 01_download_dataset.py --source openimages --max-samples 5000

# 如需补充凉鞋样本
python 01_download_dataset.py --source openimages --classes Footwear Boot Sandal --max-samples 8000
""",
    )

    parser.add_argument(
        "--source",
        choices=["ultralytics", "openimages"],
        default="openimages",
        help="数据源 (默认: openimages)",
    )
    parser.add_argument(
        "--dir",
        default="datasets/openimages-shoes",
        help="数据集保存目录",
    )
    parser.add_argument(
        "--mode",
        choices=["single-class", "roi-source"],
        default="single-class",
        help="导出模式: single-class 用于单类训练, roi-source 保留 person + shoe 供 ROI 构建",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=5000,
        help="Open Images 最大样本数 (默认: 5000)",
    )
    parser.add_argument(
        "--classes",
        nargs="+",
        default=OPENIMAGES_RECOMMENDED_CLASSES,
        help="Open Images 类别 (默认: Footwear Boot)",
    )

    args = parser.parse_args()
    success = False
    final_dataset_dir = args.dir

    print("=" * 70)
    print("鞋子检测数据集准备")
    print("=" * 70)
    if args.source == "openimages":
        print(f"推荐类别: {OPENIMAGES_RECOMMENDED_CLASSES}")
        print(f"可选补充: {OPENIMAGES_OPTIONAL_CLASSES}")
        print(f"默认不建议: {OPENIMAGES_NOT_RECOMMENDED_CLASSES}")
        if args.mode == "roi-source":
            print(f"将额外保留: {OPENIMAGES_PERSON_CLASS}")
    print()

    if args.source == "ultralytics":
        success = download_ultralytics_cppe(args.dir)
        if success:
            # The archive extracts into a nested "construction-ppe" folder,
            # hence the explicit dataset path value in the yaml.
            create_single_class_yaml(args.dir, "Construction-PPE", dataset_path_value="construction-ppe")
            check_dataset(args.dir)

    elif args.source == "openimages":
        # Redirect the output dir for ROI mode only when the user kept the
        # default --dir, so an explicit --dir always wins.
        if args.mode == "roi-source" and args.dir == parser.get_default("dir"):
            args.dir = ROI_SOURCE_DEFAULT_DIR
        success = download_openimages(args.classes, args.max_samples, args.dir, args.mode)
        # download_openimages writes the YOLO export to "<dir>-yolo".
        final_dataset_dir = args.dir + "-yolo"
        if success:
            check_dataset(final_dataset_dir)

    if success:
        print("\n" + "=" * 70)
        print("数据集准备完成!")
        print("=" * 70)
        print(f"训练数据集路径: {final_dataset_dir}")
        print("\n下一步:")
        print(f" 1. 检查配置: {final_dataset_dir}/data.yaml")
        print(f" 2. 开始训练: 02_train.bat")
        print(
            f" 3. 或手动: yolo detect train data={final_dataset_dir}/data.yaml model=yolov8s.pt epochs=150 imgsz=640"
        )
        return 0

    print("\n✗ 数据集准备失败")
    return 1
if __name__ == "__main__":
    # Propagate main()'s return value (0 = success, 1 = failure) as the
    # process exit code.
    sys.exit(main())