#!/usr/bin/env python3 """ 下载鞋子检测数据集。 支持: - Ultralytics Construction-PPE - Open Images V7 (推荐用于单类 shoe 检测) Open Images 推荐类别: - Footwear - Boot 可选补充: - Sandal 不建议默认加入: - High heels - Roller skates """ import argparse import ast import os import random import shutil import ssl import sys import urllib.request import zipfile from pathlib import Path OPENIMAGES_RECOMMENDED_CLASSES = ["Footwear", "Boot"] OPENIMAGES_OPTIONAL_CLASSES = ["Sandal"] OPENIMAGES_NOT_RECOMMENDED_CLASSES = ["High heels", "Roller skates"] OPENIMAGES_PERSON_CLASS = "Person" ROI_SOURCE_DEFAULT_DIR = "datasets/openimages-person-shoes" def download_ultralytics_cppe(dataset_dir: str = "datasets/construction-ppe"): """下载 Ultralytics Construction-PPE 数据集。""" url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/construction-ppe.zip" zip_path = "construction-ppe.zip" print("=" * 70) print("下载 Construction-PPE 数据集") print("=" * 70) print(f"来源: {url}") print(f"目标: {dataset_dir}") print() os.makedirs(dataset_dir, exist_ok=True) print("[1/3] 下载中... (约 178MB)") try: ssl_context = ssl.create_default_context() ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE with urllib.request.urlopen(url, context=ssl_context, timeout=300) as response: total_size = int(response.headers.get("content-length", 0)) downloaded = 0 chunk_size = 8192 with open(zip_path, "wb") as f: while True: chunk = response.read(chunk_size) if not chunk: break f.write(chunk) downloaded += len(chunk) if total_size > 0: percent = (downloaded / total_size) * 100 print( f"\r 进度: {percent:.1f}% ({downloaded}/{total_size} bytes)", end="", ) print("\n ✓ 下载完成") except Exception as e: print(f"\n ✗ 下载失败: {e}") print("\n请手动下载:") print(f" 1. 访问: {url}") print(" 2. 下载 construction-ppe.zip") print(f" 3. 解压到 {dataset_dir}/") return False print("\n[2/3] 解压中...") try: with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(dataset_dir) print(" ✓ 解压完成") except Exception as e: print(f" ✗ 解压失败: {e}") return False print("\n[3/3] 清理临时文件...") os.remove(zip_path) print(" ✓ 完成") return True def create_single_class_yaml(dataset_dir: str, source_name: str, dataset_path_value: str = "."): """创建单类 shoe 检测配置文件。""" yaml_content = f"""# 单类鞋子检测数据集配置 path: {dataset_path_value} train: images/train val: images/val test: images/test nc: 1 names: ['shoe'] dataset_info: name: {source_name} task: detect_shoe note: 所有鞋类子类统一映射为单一类别 shoe """ yaml_path = os.path.join(dataset_dir, "data.yaml") with open(yaml_path, "w", encoding="utf-8") as f: f.write(yaml_content) print(f"\n✓ 配置文件创建: {yaml_path}") return yaml_path def create_roi_source_yaml(dataset_dir: str, source_name: str, dataset_path_value: str = "."): """创建 person+shoe ROI 源数据集配置。""" yaml_content = f"""# 人体+鞋子 ROI 源数据集配置 path: {dataset_path_value} train: images/train val: images/val test: images/test nc: 2 names: ['person', 'shoe'] dataset_info: name: {source_name} task: detect_person_and_shoe_for_roi note: 用真实 Person 框生成脚部 ROI,shoe 为 ROI 内检测目标 """ yaml_path = os.path.join(dataset_dir, "data.yaml") with open(yaml_path, "w", encoding="utf-8") as f: f.write(yaml_content) print(f"\n✓ ROI 源配置文件创建: {yaml_path}") return yaml_path def rewrite_yaml_for_existing_splits(dataset_dir: str, source_name: str): """根据现有目录结构重写 data.yaml。""" images_root = Path(dataset_dir) / "images" split_names = [name for name in ("train", "val", "test") if (images_root / name).exists()] if not split_names: raise RuntimeError(f"未找到任何图像 split: {images_root}") train_split = "train" if "train" in split_names else split_names[0] val_split = "val" if "val" in split_names else train_split test_line = f"test: images/{'test' if 'test' in split_names else val_split}" yaml_content = f"""# 单类鞋子检测数据集配置 path: . train: images/{train_split} val: images/{val_split} {test_line} nc: 1 names: ['shoe'] dataset_info: name: {source_name} task: detect_shoe note: 所有鞋类子类统一映射为单一类别 shoe """ yaml_path = Path(dataset_dir) / "data.yaml" yaml_path.write_text(yaml_content, encoding="utf-8") print(f"\n✓ 配置文件更新: {yaml_path}") return str(yaml_path) def rewrite_roi_yaml_for_existing_splits(dataset_dir: str, source_name: str): """根据现有目录结构重写 person+shoe ROI 源 data.yaml。""" images_root = Path(dataset_dir) / "images" split_names = [name for name in ("train", "val", "test") if (images_root / name).exists()] if not split_names: raise RuntimeError(f"未找到任何图像 split: {images_root}") train_split = "train" if "train" in split_names else split_names[0] val_split = "val" if "val" in split_names else train_split test_line = f"test: images/{'test' if 'test' in split_names else val_split}" yaml_content = f"""# 人体+鞋子 ROI 源数据集配置 path: . train: images/{train_split} val: images/{val_split} {test_line} nc: 2 names: ['person', 'shoe'] dataset_info: name: {source_name} task: detect_person_and_shoe_for_roi note: 用真实 Person 框生成脚部 ROI,shoe 为 ROI 内检测目标 """ yaml_path = Path(dataset_dir) / "data.yaml" yaml_path.write_text(yaml_content, encoding="utf-8") print(f"\n✓ ROI 源配置文件更新: {yaml_path}") return str(yaml_path) def load_dataset_name_map(dataset_yaml: Path) -> dict[int, str]: """读取 YOLO names,不依赖额外 yaml 库。""" names: dict[int, str] = {} lines = dataset_yaml.read_text(encoding="utf-8").splitlines() for index, line in enumerate(lines): stripped = line.strip() if not stripped.startswith("names:"): continue inline_value = stripped[len("names:") :].strip() if inline_value: value = ast.literal_eval(inline_value) if isinstance(value, list): return {idx: str(name) for idx, name in enumerate(value)} for child in lines[index + 1 :]: if not child.startswith(" "): break child_stripped = child.strip() if ":" not in child_stripped: continue key_text, value_text = child_stripped.split(":", 1) if key_text.strip().isdigit(): names[int(key_text.strip())] = value_text.strip().strip("'\"") break return names def ensure_openimages_train_val_split(export_dir: str, train_ratio: float = 0.9, seed: int = 42): """如果导出结果只有单个 split,则自动切分为 train/val。""" images_root = Path(export_dir) / "images" labels_root = Path(export_dir) / "labels" train_images = images_root / "train" val_images = images_root / "val" train_labels = labels_root / "train" val_labels = labels_root / "val" if train_images.exists() and train_labels.exists(): return if not val_images.exists() or not val_labels.exists(): return image_files = sorted([p for p in val_images.iterdir() if p.is_file()]) if len(image_files) < 2: return train_images.mkdir(parents=True, exist_ok=True) train_labels.mkdir(parents=True, exist_ok=True) rng = random.Random(seed) rng.shuffle(image_files) split_idx = max(1, min(len(image_files) - 1, int(len(image_files) * train_ratio))) train_files = image_files[:split_idx] for image_path in train_files: label_path = val_labels / f"{image_path.stem}.txt" shutil.move(str(image_path), train_images / image_path.name) if label_path.exists(): shutil.move(str(label_path), train_labels / label_path.name) print("\n自动切分单一 split 为 train/val:") print(f" train: {len(list(train_images.iterdir()))} 张") print(f" val: {len(list(val_images.iterdir()))} 张") def merge_openimages_to_single_class(export_dir: str): """将 Open Images 导出的多类别 YOLO 标签合并为单类 shoe。""" labels_root = Path(export_dir) / "labels" if not labels_root.exists(): print(f"✗ 未找到标签目录: {labels_root}") return False total_files = 0 total_boxes = 0 changed_files = 0 for label_file in labels_root.rglob("*.txt"): total_files += 1 lines = label_file.read_text(encoding="utf-8").splitlines() rewritten = [] file_changed = False for line in lines: parts = line.strip().split() if len(parts) < 5: continue parts[0] = "0" rewritten.append(" ".join(parts)) total_boxes += 1 file_changed = True label_file.write_text("\n".join(rewritten) + ("\n" if rewritten else ""), encoding="utf-8") if file_changed: changed_files += 1 ensure_openimages_train_val_split(export_dir) rewrite_yaml_for_existing_splits(export_dir, "Open Images V7") print("\n单类合并完成:") print(f" 标签文件: {changed_files}/{total_files}") print(f" 标注框数: {total_boxes}") return True def rewrite_openimages_to_roi_source(export_dir: str): """将 Open Images 导出的标签改写为 person+shoe ROI 源数据。""" labels_root = Path(export_dir) / "labels" dataset_yaml = Path(export_dir) / "dataset.yaml" if not labels_root.exists(): print(f"✗ 未找到标签目录: {labels_root}") return False if not dataset_yaml.exists(): print(f"✗ 未找到 dataset.yaml: {dataset_yaml}") return False name_map = load_dataset_name_map(dataset_yaml) person_ids = {idx for idx, name in name_map.items() if name.lower() == "person"} shoe_ids = { idx for idx, name in name_map.items() if name.lower() in {"footwear", "boot", "sandal", "high heels", "roller skates"} } if not person_ids: print("✗ 导出结果中未找到 Person 类") return False if not shoe_ids: print("✗ 导出结果中未找到鞋类") return False kept_files = 0 total_person = 0 total_shoe = 0 for label_file in labels_root.rglob("*.txt"): lines = label_file.read_text(encoding="utf-8").splitlines() rewritten: list[str] = [] file_person = 0 file_shoe = 0 for line in lines: parts = line.strip().split() if len(parts) < 5 or not parts[0].isdigit(): continue class_id = int(parts[0]) if class_id in person_ids: parts[0] = "0" rewritten.append(" ".join(parts)) total_person += 1 file_person += 1 elif class_id in shoe_ids: parts[0] = "1" rewritten.append(" ".join(parts)) total_shoe += 1 file_shoe += 1 if file_person > 0 and file_shoe > 0: label_file.write_text("\n".join(rewritten) + "\n", encoding="utf-8") kept_files += 1 else: label_file.write_text("", encoding="utf-8") ensure_openimages_train_val_split(export_dir) rewrite_roi_yaml_for_existing_splits(export_dir, "Open Images V7 ROI Source") print("\nROI 源重写完成:") print(f" 同时含 person + shoe 的标签文件: {kept_files}") print(f" Person 标注框数: {total_person}") print(f" Shoe 标注框数: {total_shoe}") return True def download_openimages(classes: list, max_samples: int, dataset_dir: str, mode: str): """通过 FiftyOne 下载 Open Images 并导出为单类或 ROI 源数据集。""" try: import fiftyone as fo import fiftyone.zoo as foz except ImportError: print("错误: 未安装 fiftyone") print("请运行: pip install fiftyone fiftyone-db-ubuntu2204") return False export_dir = dataset_dir + "-yolo" requested_classes = list(classes) if mode == "roi-source" and OPENIMAGES_PERSON_CLASS not in requested_classes: requested_classes = [OPENIMAGES_PERSON_CLASS] + requested_classes print("=" * 70) print("下载 Open Images V7 数据集") print("=" * 70) print(f"模式: {mode}") print(f"类别: {requested_classes}") print(f"最大样本数: {max_samples}") print("原始缓存目录: FiftyOne 默认缓存目录") print(f"YOLO 导出目录: {export_dir}") print() if os.path.exists(export_dir): print(f"检测到已有导出目录,先删除: {export_dir}") shutil.rmtree(export_dir) try: dataset = foz.load_zoo_dataset( "open-images-v7", split="train", label_types=["detections"], classes=requested_classes, max_samples=max_samples, ) print("\n导出为 YOLO 格式...") dataset.export( export_dir=export_dir, dataset_type=fo.types.YOLOv5Dataset, label_field="ground_truth", ) if mode == "roi-source": if not rewrite_openimages_to_roi_source(export_dir): return False else: if not merge_openimages_to_single_class(export_dir): return False print(f"\n✓ 数据集保存: {export_dir}") return True except Exception as e: print(f"✗ 下载失败: {e}") return False def check_dataset(dataset_dir: str): """检查数据集完整性。""" print("\n" + "=" * 70) print("检查数据集") print("=" * 70) required_dirs = [ "images/train", "images/val", "images/test", "labels/train", "labels/val", "labels/test", ] all_ok = True for dir_name in required_dirs: full_path = os.path.join(dataset_dir, dir_name) if os.path.exists(full_path): count = len( [f for f in os.listdir(full_path) if os.path.isfile(os.path.join(full_path, f))] ) print(f" ✓ {dir_name}: {count} 个文件") else: print(f" ✗ {dir_name}: 不存在") all_ok = False yaml_path = os.path.join(dataset_dir, "data.yaml") if os.path.exists(yaml_path): print(f" ✓ data.yaml: 已生成") else: print(f" ✗ data.yaml: 不存在") all_ok = False return all_ok def main(): parser = argparse.ArgumentParser( description="下载鞋子检测数据集", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 示例: # 下载 Construction-PPE python 01_download_dataset.py --source ultralytics # 下载 Open Images 推荐鞋子类别 python 01_download_dataset.py --source openimages --max-samples 5000 # 如需补充凉鞋样本 python 01_download_dataset.py --source openimages --classes Footwear Boot Sandal --max-samples 8000 """, ) parser.add_argument( "--source", choices=["ultralytics", "openimages"], default="openimages", help="数据源 (默认: openimages)", ) parser.add_argument( "--dir", default="datasets/openimages-shoes", help="数据集保存目录", ) parser.add_argument( "--mode", choices=["single-class", "roi-source"], default="single-class", help="导出模式: single-class 用于单类训练, roi-source 保留 person + shoe 供 ROI 构建", ) parser.add_argument( "--max-samples", type=int, default=5000, help="Open Images 最大样本数 (默认: 5000)", ) parser.add_argument( "--classes", nargs="+", default=OPENIMAGES_RECOMMENDED_CLASSES, help="Open Images 类别 (默认: Footwear Boot)", ) args = parser.parse_args() success = False final_dataset_dir = args.dir print("=" * 70) print("鞋子检测数据集准备") print("=" * 70) if args.source == "openimages": print(f"推荐类别: {OPENIMAGES_RECOMMENDED_CLASSES}") print(f"可选补充: {OPENIMAGES_OPTIONAL_CLASSES}") print(f"默认不建议: {OPENIMAGES_NOT_RECOMMENDED_CLASSES}") if args.mode == "roi-source": print(f"将额外保留: {OPENIMAGES_PERSON_CLASS}") print() if args.source == "ultralytics": success = download_ultralytics_cppe(args.dir) if success: create_single_class_yaml(args.dir, "Construction-PPE", dataset_path_value="construction-ppe") check_dataset(args.dir) elif args.source == "openimages": if args.mode == "roi-source" and args.dir == parser.get_default("dir"): args.dir = ROI_SOURCE_DEFAULT_DIR success = download_openimages(args.classes, args.max_samples, args.dir, args.mode) final_dataset_dir = args.dir + "-yolo" if success: check_dataset(final_dataset_dir) if success: print("\n" + "=" * 70) print("数据集准备完成!") print("=" * 70) print(f"训练数据集路径: {final_dataset_dir}") print("\n下一步:") print(f" 1. 检查配置: {final_dataset_dir}/data.yaml") print(f" 2. 开始训练: 02_train.bat") print( f" 3. 或手动: yolo detect train data={final_dataset_dir}/data.yaml model=yolov8s.pt epochs=150 imgsz=640" ) return 0 print("\n✗ 数据集准备失败") return 1 if __name__ == "__main__": sys.exit(main())