DetectionModelTraining/05_prepare_ppe_shoe_subset.py

175 lines
5.3 KiB
Python

#!/usr/bin/env python3
"""从 Construction-PPE 中提取单类 shoe 或 person+shoe ROI 源数据。"""
import argparse
import shutil
from pathlib import Path
# Source class IDs that convert_split keeps. They are stored as strings because
# they are compared against the raw first token of each label line.
SHOE_CLASSES = {"3", "10"} # boots, no_boots
PERSON_CLASSES = {"6"} # Person
def ensure_clean_dir(path: Path) -> None:
    """Make ``path`` an existing, empty directory.

    Whatever is already present at ``path`` is removed first: a directory tree
    is deleted recursively, while a regular file or a symlink (including a
    dangling one) is unlinked, because ``shutil.rmtree`` refuses both.
    Missing parent directories are created.

    Args:
        path: Directory to (re)create. Its previous contents are lost.
    """
    if path.is_symlink() or path.is_file():
        # Only the entry itself is removed; a symlink's target is left alone.
        path.unlink()
    elif path.exists():
        shutil.rmtree(path)
    path.mkdir(parents=True, exist_ok=True)
def write_shoe_yaml(output_dir: Path) -> None:
    """Write the single-class ``shoe`` dataset config to ``output_dir/data.yaml``.

    The file records the absolute dataset root, the ``images/train``,
    ``images/val`` and ``images/test`` split folders, ``nc: 1`` and the single
    class name ``shoe``.

    The root path is emitted as a double-quoted scalar so that characters such
    as `` #`` or ``: `` inside the path are not interpreted as YAML syntax.

    Args:
        output_dir: Dataset root directory; created if it does not exist yet.
    """
    import json  # local import: only used to quote the path scalar

    output_dir.mkdir(parents=True, exist_ok=True)
    abs_output = output_dir.resolve().as_posix()
    # A JSON string literal is also a valid YAML double-quoted scalar.
    quoted_path = json.dumps(abs_output, ensure_ascii=False)

    config_lines = [
        "# PPE 鞋子单类微调数据集配置",
        "",
        f"path: {quoted_path}",
        "train: images/train",
        "val: images/val",
        "test: images/test",
        "",
        "nc: 1",
        "names: ['shoe']",
        "",
        "dataset_info:",
        " name: Construction-PPE shoe subset",
        " source: Construction-PPE",
        " note: 从 boots 和 no_boots 提取并统一映射为 shoe",
        "",  # trailing empty entry -> file ends with a newline
    ]
    (output_dir / "data.yaml").write_text("\n".join(config_lines), encoding="utf-8")
def write_roi_source_yaml(output_dir: Path) -> None:
    """Write the two-class ``person`` + ``shoe`` config to ``output_dir/data.yaml``.

    The file records the absolute dataset root, the ``images/train``,
    ``images/val`` and ``images/test`` split folders, ``nc: 2`` and the class
    names ``person`` (index 0) and ``shoe`` (index 1).

    The root path is emitted as a double-quoted scalar so that characters such
    as `` #`` or ``: `` inside the path are not interpreted as YAML syntax.

    Args:
        output_dir: Dataset root directory; created if it does not exist yet.
    """
    import json  # local import: only used to quote the path scalar

    output_dir.mkdir(parents=True, exist_ok=True)
    abs_output = output_dir.resolve().as_posix()
    # A JSON string literal is also a valid YAML double-quoted scalar.
    quoted_path = json.dumps(abs_output, ensure_ascii=False)

    config_lines = [
        "# PPE 人体+鞋子 ROI 源数据配置",
        "",
        f"path: {quoted_path}",
        "train: images/train",
        "val: images/val",
        "test: images/test",
        "",
        "nc: 2",
        "names: ['person', 'shoe']",
        "",
        "dataset_info:",
        " name: Construction-PPE ROI source",
        " source: Construction-PPE",
        " note: 保留 Person 和鞋类,用真实人框生成脚部 ROI",
        "",  # trailing empty entry -> file ends with a newline
    ]
    (output_dir / "data.yaml").write_text("\n".join(config_lines), encoding="utf-8")
def _find_image(image_dir: Path, stem: str):
    """Return the image file in ``image_dir`` named ``stem`` + a known suffix, or None."""
    suffixes = (".jpg", ".png", ".jpeg", ".bmp", ".webp", ".tif", ".tiff")
    # Lower-case ".jpg" then ".png" are tried first, so any image the previous
    # two-suffix lookup found is still the one selected here.
    for suffix in suffixes + tuple(s.upper() for s in suffixes):
        candidate = image_dir / f"{stem}{suffix}"
        if candidate.is_file():
            return candidate
    return None


def convert_split(source_dir: Path, output_dir: Path, split: str, mode: str):
    """Copy one split, keeping only shoe (and optionally person) boxes.

    Every ``*.txt`` file under ``source_dir/labels/<split>`` is read as
    whitespace-separated label lines whose first field is the class id; lines
    with fewer than five fields are ignored. Kept lines get a new class id and
    all remaining fields are passed through unchanged:

    * ``mode == "roi-source"``: ``PERSON_CLASSES`` -> ``0``, ``SHOE_CLASSES`` -> ``1``;
      a file is kept only if it has at least one person and one shoe box.
    * any other mode: ``SHOE_CLASSES`` -> ``0``; a file is kept only if it has
      at least one shoe box.

    For each kept label file the image with the same stem is copied from
    ``source_dir/images/<split>`` to ``output_dir/images/<split>`` and the
    filtered labels are written to ``output_dir/labels/<split>``. Label files
    without a matching image are skipped silently.

    Args:
        source_dir: Root of the source dataset.
        output_dir: Root of the dataset being written.
        split: Split folder name, e.g. ``"train"``.
        mode: ``"roi-source"`` to keep person + shoe, anything else for shoe only.

    Returns:
        ``(kept_images, kept_boxes)`` for this split.
    """
    image_src = source_dir / "images" / split
    label_src = source_dir / "labels" / split
    image_dst = output_dir / "images" / split
    label_dst = output_dir / "labels" / split
    # Destination folders are created even when nothing ends up being copied.
    image_dst.mkdir(parents=True, exist_ok=True)
    label_dst.mkdir(parents=True, exist_ok=True)

    roi_mode = mode == "roi-source"
    shoe_target = "1" if roi_mode else "0"

    kept_images = 0
    kept_boxes = 0
    for label_file in sorted(label_src.glob("*.txt")):
        out_lines = []
        person_count = 0
        shoe_count = 0
        for raw_line in label_file.read_text(encoding="utf-8").splitlines():
            fields = raw_line.split()
            if len(fields) < 5:
                continue
            class_id = fields[0]
            if roi_mode and class_id in PERSON_CLASSES:
                fields[0] = "0"
                person_count += 1
            elif class_id in SHOE_CLASSES:
                fields[0] = shoe_target
                shoe_count += 1
            else:
                continue
            out_lines.append(" ".join(fields))

        if shoe_count == 0 or (roi_mode and person_count == 0):
            continue

        image_file = _find_image(image_src, label_file.stem)
        if image_file is None:
            continue

        shutil.copy2(image_file, image_dst / image_file.name)
        (label_dst / label_file.name).write_text(
            "\n".join(out_lines) + "\n", encoding="utf-8"
        )
        kept_images += 1
        kept_boxes += len(out_lines)
    return kept_images, kept_boxes
def _paths_overlap(first: Path, second: Path) -> bool:
    """Return True if the two resolved paths are equal or one contains the other."""
    return first == second or first in second.parents or second in first.parents


def main(argv=None):
    """Command-line entry point: build the shoe or person+shoe subset.

    The output directory is wiped and recreated, so before touching it the
    function checks that the source dataset exists and that wiping the output
    cannot delete the source ``images``/``labels`` trees. Either problem ends
    the program through ``parser.error`` (exit status 2).

    When ``--output`` is omitted the default depends on ``--mode``:
    ``datasets/ppe-shoes`` for ``shoe-only`` and ``datasets/ppe-person-shoes``
    for ``roi-source``. An explicitly given ``--output`` is always honoured.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    default_outputs = {
        "shoe-only": "datasets/ppe-shoes",
        "roi-source": "datasets/ppe-person-shoes",
    }

    parser = argparse.ArgumentParser(description="提取 PPE 鞋子单类子集")
    parser.add_argument(
        "--source",
        default="datasets/construction-ppe",
        help="Construction-PPE 数据集目录",
    )
    parser.add_argument(
        "--output",
        # None marks "not given", so the per-mode default never overrides an
        # explicit user choice.
        default=None,
        help="输出目录",
    )
    parser.add_argument(
        "--mode",
        choices=["shoe-only", "roi-source"],
        default="shoe-only",
        help="shoe-only 输出单类 shoe; roi-source 保留 person + shoe 供 ROI 构建使用",
    )
    args = parser.parse_args(argv)

    source_dir = Path(args.source)
    output_dir = Path(
        args.output if args.output is not None else default_outputs[args.mode]
    )

    # Validate before ensure_clean_dir(): that call is destructive.
    if not source_dir.is_dir():
        parser.error(f"源数据集目录不存在: {source_dir}")
    resolved_source = source_dir.resolve()
    resolved_output = output_dir.resolve()
    for subdir in ("images", "labels"):
        if _paths_overlap(resolved_output, resolved_source / subdir):
            parser.error(
                f"输出目录与源数据目录重叠,清空输出目录会删除源数据: {output_dir}"
            )

    ensure_clean_dir(output_dir)

    total_images = 0
    total_boxes = 0
    for split in ("train", "val", "test"):
        kept_images, kept_boxes = convert_split(source_dir, output_dir, split, args.mode)
        total_images += kept_images
        total_boxes += kept_boxes
        print(f"[{split}] images={kept_images} boxes={kept_boxes}")

    if args.mode == "roi-source":
        write_roi_source_yaml(output_dir)
    else:
        write_shoe_yaml(output_dir)

    print(f"\n输出目录: {output_dir}")
    print(f"总图片数: {total_images}")
    print(f"总框数: {total_boxes}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()