#!/usr/bin/env python3
|
|
"""从 Construction-PPE 中提取单类 shoe 子集。"""
|
|
|
|
import argparse
import json
import shutil
from pathlib import Path
|
|
|
|
|
|
# Construction-PPE source class ids, as the string tokens that appear first on
# each label row: "3" = boots, "10" = no_boots. Both are merged into the single
# output class "shoe". A frozenset keeps this module constant immutable.
SHOE_CLASSES = frozenset({"3", "10"})
|
|
|
|
|
|
def ensure_clean_dir(path: Path):
    """Recreate *path* as an empty directory.

    Whatever currently lives at *path* is removed with ``shutil.rmtree``.
    The directory is then created again, along with any missing parents.

    Args:
        path: Directory to reset. Its existing contents are destroyed.
    """
    stale = path.exists()
    if stale:
        # Drop the previous run's tree wholesale so no leftover images or
        # labels end up mixed into the freshly generated subset.
        shutil.rmtree(path)
    path.mkdir(exist_ok=True, parents=True)
|
|
|
|
|
|
def write_yaml(output_dir: Path):
    """Write ``output_dir/data.yaml`` describing the single-class shoe subset.

    The config points ``path`` at the resolved, absolute ``output_dir``. It
    lists ``images/train``, ``images/val`` and ``images/test`` relative to
    that root, and declares one class, ``shoe``. The key layout follows the
    YOLO dataset-config convention, which is assumed to be the consumer.

    Args:
        output_dir: Root of the generated subset. It must already exist,
            because ``write_text`` does not create parent directories.
    """
    yaml_path = output_dir / "data.yaml"
    abs_output = output_dir.resolve().as_posix()
    # Quote the path: as a plain YAML scalar, a directory name containing
    # " #" would be cut off as a comment, and ": " or a leading quote or
    # bracket would break parsing. JSON string syntax is a valid YAML
    # double-quoted scalar, so json.dumps yields correct escaping.
    quoted_output = json.dumps(abs_output, ensure_ascii=False)
    yaml_path.write_text(
        "\n".join(
            [
                "# PPE 鞋子单类微调数据集配置",
                "",
                f"path: {quoted_output}",
                "train: images/train",
                "val: images/val",
                "test: images/test",
                "",
                "nc: 1",
                "names: ['shoe']",
                "",
                "dataset_info:",
                "  name: Construction-PPE shoe subset",
                "  source: Construction-PPE",
                "  note: 从 boots 和 no_boots 提取并统一映射为 shoe",
                "",
            ]
        ),
        encoding="utf-8",
    )
|
|
|
|
|
|
# Image extensions probed for each label file, in priority order. The
# lower-case ".jpg" then ".png" come first, so datasets that only use those
# resolve exactly as before.
_IMAGE_SUFFIXES = (
    ".jpg", ".png", ".jpeg", ".bmp", ".webp",
    ".JPG", ".PNG", ".JPEG", ".BMP", ".WEBP",
)


def _find_image(image_dir: Path, stem: str):
    """Return the first existing file ``image_dir/<stem><suffix>``, or None."""
    for suffix in _IMAGE_SUFFIXES:
        candidate = image_dir / f"{stem}{suffix}"
        # is_file (not exists) so a same-named directory is never handed to copy2.
        if candidate.is_file():
            return candidate
    return None


def convert_split(source_dir: Path, output_dir: Path, split: str, *, classes=None):
    """Copy one split's shoe-bearing images and remap their labels to class 0.

    Reads every ``*.txt`` label file in ``source_dir/labels/<split>``. Each
    row is whitespace-separated with the class id as the first token (YOLO
    label layout). Only rows whose class id is in *classes* and that have at
    least 5 tokens are kept. Their class id is rewritten to ``"0"``, while
    the remaining tokens are kept unchanged. The filtered label file and a
    copy of the matching image are written to
    ``output_dir/labels/<split>`` and ``output_dir/images/<split>``.

    Label files with no kept rows are skipped. Label files whose image cannot
    be found are also skipped and counted in a warning. A split that does not
    exist in the source yields ``(0, 0)``.

    Args:
        source_dir: Root of the source dataset.
        output_dir: Root of the subset being built.
        split: Split name, e.g. ``"train"``.
        classes: Optional iterable of source class ids to keep, defaulting to
            ``SHOE_CLASSES``. Ids are compared as strings, so ``3`` and
            ``"3"`` are equivalent.

    Returns:
        ``(kept_images, kept_boxes)`` for this split.
    """
    wanted = SHOE_CLASSES if classes is None else {str(c) for c in classes}

    image_src = source_dir / "images" / split
    label_src = source_dir / "labels" / split
    image_dst = output_dir / "images" / split
    label_dst = output_dir / "labels" / split

    image_dst.mkdir(parents=True, exist_ok=True)
    label_dst.mkdir(parents=True, exist_ok=True)

    kept_images = 0
    kept_boxes = 0
    missing_images = 0

    for label_file in sorted(label_src.glob("*.txt")):
        # utf-8-sig strips a BOM that would otherwise glue itself onto the
        # first row's class id and silently drop that row.
        lines = label_file.read_text(encoding="utf-8-sig").splitlines()
        shoe_lines = []
        for line in lines:
            parts = line.split()
            # Skip blank/short rows (class id + 4 box values minimum) and
            # rows belonging to classes we are not extracting.
            if len(parts) < 5 or parts[0] not in wanted:
                continue
            parts[0] = "0"  # every kept class maps to the single output class
            shoe_lines.append(" ".join(parts))

        if not shoe_lines:
            continue

        image_file = _find_image(image_src, label_file.stem)
        if image_file is None:
            missing_images += 1
            continue

        shutil.copy2(image_file, image_dst / image_file.name)
        (label_dst / label_file.name).write_text("\n".join(shoe_lines) + "\n", encoding="utf-8")
        kept_images += 1
        kept_boxes += len(shoe_lines)

    if missing_images:
        # Surface dropped annotations instead of losing them silently.
        print(f"[{split}] warning: skipped {missing_images} label file(s) without a matching image")

    return kept_images, kept_boxes
|
|
|
|
|
|
def main():
    """CLI entry point: rebuild the single-class shoe subset from scratch.

    Parses ``--source`` and ``--output``, validates them, and wipes and
    recreates the output directory. It then converts the train/val/test
    splits, writes ``data.yaml``, and prints per-split and total counts.
    Validation happens before anything is deleted, because the output
    directory is removed with ``shutil.rmtree`` on every run.
    """
    parser = argparse.ArgumentParser(description="提取 PPE 鞋子单类子集")
    parser.add_argument(
        "--source",
        default="datasets/construction-ppe",
        help="Construction-PPE 数据集目录",
    )
    parser.add_argument(
        "--output",
        default="datasets/ppe-shoes",
        help="输出目录",
    )
    args = parser.parse_args()

    source_dir = Path(args.source)
    output_dir = Path(args.output)

    # Fail fast on a wrong --source. Otherwise the previous output would be
    # wiped and replaced with a silently empty dataset.
    if not (source_dir / "labels").is_dir():
        parser.error(f"源数据集目录无效（缺少 labels/ 子目录）: {source_dir}")

    # The output tree is deleted below. An --output equal to (or containing)
    # --source would delete the source dataset itself, so refuse any overlap
    # between the two paths.
    src_abs = source_dir.resolve()
    out_abs = output_dir.resolve()
    if out_abs == src_abs or out_abs in src_abs.parents or src_abs in out_abs.parents:
        parser.error(
            f"--output ({output_dir}) 不能与 --source ({source_dir}) 相同或互相包含，"
            "否则清空输出目录时会删除源数据"
        )

    ensure_clean_dir(output_dir)

    total_images = 0
    total_boxes = 0
    for split in ("train", "val", "test"):
        kept_images, kept_boxes = convert_split(source_dir, output_dir, split)
        total_images += kept_images
        total_boxes += kept_boxes
        print(f"[{split}] images={kept_images} boxes={kept_boxes}")

    write_yaml(output_dir)
    print(f"\n输出目录: {output_dir}")
    print(f"总图片数: {total_images}")
    print(f"总框数: {total_boxes}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()
|