49 lines
1.4 KiB
Python
49 lines
1.4 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
from dataclasses import dataclass
|
|
from typing import Iterable, List, Tuple
|
|
|
|
|
|
# File extensions (lower-case, with the leading dot) that DatasetScanner looks
# up to decide whether a file is an image; names are lower-cased before lookup.
_IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}
|
|
|
|
|
|
@dataclass(frozen=True)
class PersonSamples:
    """Image files collected for a single person of the dataset.

    Attributes:
        name: Label of the person; ``DatasetScanner.iter_persons`` fills it
            with the name of the person's directory.
        image_paths: Paths of that person's image files;
            ``DatasetScanner.iter_persons`` supplies them absolute and sorted.
    """

    # NOTE: ``frozen=True`` only prevents re-assigning the fields; the list
    # held in ``image_paths`` can still be mutated in place.
    name: str
    image_paths: List[str]
|
|
|
|
|
|
class DatasetScanner:
    """Scan a dataset laid out as one sub-directory of images per person.

    Expected layout::

        dataset_root/
            alice/   a.jpg  b.png
            bob/     c.jpeg

    Only the immediate sub-directories of ``dataset_root`` are treated as
    persons, and only regular files directly inside each of them are
    considered; nothing is searched recursively.
    """

    def __init__(
        self,
        dataset_root: str,
        image_exts: Iterable[str] | None = None,
    ) -> None:
        """Remember where the dataset lives and which files count as images.

        Args:
            dataset_root: Directory holding one sub-directory per person.
                Stored as an absolute path; its existence is only checked
                when the dataset is actually scanned.
            image_exts: File extensions accepted as images. They are matched
                case-insensitively and may be given with or without the
                leading dot (``"webp"`` and ``".WEBP"`` are equivalent). A
                single string is treated as one extension. ``None`` (the
                default) selects the module-level ``_IMG_EXTS`` set.
        """
        self.dataset_root = os.path.abspath(dataset_root)

        if image_exts is None:
            image_exts = _IMG_EXTS
        elif isinstance(image_exts, str):
            # Iterating a bare string would yield single characters.
            image_exts = (image_exts,)

        # Normalise to the ".ext" lower-case form produced by
        # ``os.path.splitext(name)[1].lower()``; blank entries are dropped so
        # that they cannot match extension-less files.
        stripped = (ext.lower().lstrip(".") for ext in image_exts)
        self.image_exts = frozenset(f".{ext}" for ext in stripped if ext)

    def _list_images(self, directory: str) -> List[str]:
        """Return the sorted absolute paths of image files directly in *directory*."""
        with os.scandir(directory) as entries:
            return sorted(
                os.path.abspath(entry.path)
                for entry in entries
                if entry.is_file()
                and os.path.splitext(entry.name)[1].lower() in self.image_exts
            )

    def iter_persons(self) -> Iterable[PersonSamples]:
        """Yield one ``PersonSamples`` per person directory, ordered by name.

        This is a generator: the file system is not touched, and no error is
        raised, until iteration actually starts. Person directories without
        any matching image are still yielded, with an empty ``image_paths``.

        Raises:
            FileNotFoundError: If ``dataset_root`` is not an existing
                directory (raised on the first iteration step).
        """
        if not os.path.isdir(self.dataset_root):
            raise FileNotFoundError(f"dataset root not found: {self.dataset_root}")

        # Collect the whole listing first so the scandir handle is closed
        # before control is handed back to the consumer between yields.
        with os.scandir(self.dataset_root) as entries:
            person_dirs = sorted(
                (entry for entry in entries if entry.is_dir()),
                key=lambda entry: entry.name,
            )

        for person_dir in person_dirs:
            yield PersonSamples(
                name=person_dir.name,
                image_paths=self._list_images(person_dir.path),
            )

    def summary(self) -> Tuple[int, int]:
        """Return ``(person_count, image_count)`` for the whole dataset.

        Every person directory is counted, including those with no images.

        Raises:
            FileNotFoundError: If ``dataset_root`` is not an existing directory.
        """
        persons = 0
        images = 0
        for sample in self.iter_persons():
            persons += 1
            images += len(sample.image_paths)
        return persons, images
|