diff --git a/engine/task_loader.py b/engine/task_loader.py index ecec106..2781f26 100644 --- a/engine/task_loader.py +++ b/engine/task_loader.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +from typing import Any import yaml @@ -23,87 +24,83 @@ class TaskValidationError(ValueError): pass -def _require_mapping(value: object, path: str) -> dict[str, object]: +def _require_mapping(value: Any, path: str) -> dict[str, Any]: if not isinstance(value, dict): raise TaskValidationError(f"{path} must be a mapping") return value -def _require_list(value: object, path: str) -> list[object]: +def _require_list(value: Any, path: str) -> list[Any]: if not isinstance(value, list): raise TaskValidationError(f"{path} must be a list") return value -def _require_value(mapping: dict[str, object], key: str, path: str) -> object: +def _require_value(mapping: dict[str, Any], key: str) -> Any: if key not in mapping: - raise TaskValidationError(f"Missing required field: {path}.{key}") + raise TaskValidationError(f"missing required field: {key}") return mapping[key] -def _require_str(mapping: dict[str, object], key: str, path: str) -> str: - value = _require_value(mapping, key, path) - if not isinstance(value, str): - raise TaskValidationError(f"{path}.{key} must be a string") - return value - - -def _require_int(mapping: dict[str, object], key: str, path: str) -> int: - value = _require_value(mapping, key, path) - if not isinstance(value, int) or isinstance(value, bool): - raise TaskValidationError(f"{path}.{key} must be an integer") - return value - - -def _require_str_list(mapping: dict[str, object], key: str, path: str) -> list[str]: - items = _require_list(_require_value(mapping, key, path), f"{path}.{key}") - result: list[str] = [] - for index, item in enumerate(items): - if not isinstance(item, str): - raise TaskValidationError(f"{path}.{key}[{index}] must be a string") - result.append(item) - return result - - -def _require_tie_breakers(mapping: dict[str, object], key: str, path: str) -> list[dict[str, str]]: - items = _require_list(_require_value(mapping, key, path), f"{path}.{key}") - result: list[dict[str, str]] = [] - for index, item in enumerate(items): - entry = _require_mapping(item, f"{path}.{key}[{index}]") - result.append({str(k): str(v) for k, v in entry.items()}) - return result - - def load_task(task_path: Path) -> TaskSpec: + def _require_str(mapping: dict[str, Any], key: str, path: str) -> str: + value = _require_value(mapping, key) + if not isinstance(value, str): + raise TaskValidationError(f"{path}.{key} must be a string") + return value + + def _require_int(mapping: dict[str, Any], key: str, path: str) -> int: + value = _require_value(mapping, key) + if not isinstance(value, int) or isinstance(value, bool): + raise TaskValidationError(f"{path}.{key} must be an integer") + return value + + def _require_str_list(mapping: dict[str, Any], key: str, path: str) -> list[str]: + items = _require_list(_require_value(mapping, key), f"{path}.{key}") + result: list[str] = [] + for index, item in enumerate(items): + if not isinstance(item, str): + raise TaskValidationError(f"{path}.{key}[{index}] must be a string") + result.append(item) + return result + + def _require_tie_breakers(mapping: dict[str, Any], key: str, path: str) -> list[dict[str, str]]: + items = _require_list(_require_value(mapping, key), f"{path}.{key}") + result: list[dict[str, str]] = [] + for index, item in enumerate(items): + entry = _require_mapping(item, f"{path}.{key}[{index}]") + result.append({str(k): str(v) for k, v in entry.items()}) + return result + task_data = yaml.safe_load(task_path.read_text(encoding="utf-8")) root = _require_mapping(task_data, "task") task_id = _require_str(root, "id", "task") description = _require_str(root, "description", "task") - artifacts_data = _require_mapping(_require_value(root, "artifacts", "task"), "task.artifacts") + artifacts_data = _require_mapping(_require_value(root, "artifacts"), "task.artifacts") artifacts = ArtifactSpec( include=_require_str_list(artifacts_data, "include", "task.artifacts"), exclude=_require_str_list(artifacts_data, "exclude", "task.artifacts"), max_files_per_iteration=_require_int(artifacts_data, "max_files_per_iteration", "task.artifacts"), ) - mutation_data = _require_mapping(_require_value(root, "mutation", "task"), "task.mutation") + mutation_data = _require_mapping(_require_value(root, "mutation"), "task.mutation") mutation = MutationSpec( mode=_require_str(mutation_data, "mode", "task.mutation"), allowed_file_types=_require_str_list(mutation_data, "allowed_file_types", "task.mutation"), max_changed_lines=_require_int(mutation_data, "max_changed_lines", "task.mutation"), ) - runner_data = _require_mapping(_require_value(root, "runner", "task"), "task.runner") + runner_data = _require_mapping(_require_value(root, "runner"), "task.runner") runner = RunnerSpec( command=_require_str(runner_data, "command", "task.runner"), cwd=_require_str(runner_data, "cwd", "task.runner"), timeout_seconds=_require_int(runner_data, "timeout_seconds", "task.runner"), ) - scorer_data = _require_mapping(_require_value(root, "scorer", "task"), "task.scorer") - scorer_parse_data = _require_mapping(_require_value(scorer_data, "parse", "task.scorer"), "task.scorer.parse") + scorer_data = _require_mapping(_require_value(root, "scorer"), "task.scorer") + scorer_parse_data = _require_mapping(_require_value(scorer_data, "parse"), "task.scorer.parse") scorer = ScorerSpec( type=_require_str(scorer_data, "type", "task.scorer"), command=_require_str(scorer_data, "command", "task.scorer"), @@ -114,7 +111,7 @@ def load_task(task_path: Path) -> TaskSpec: ), ) - objective_data = _require_mapping(_require_value(root, "objective", "task"), "task.objective") + objective_data = _require_mapping(_require_value(root, "objective"), "task.objective") direction = _require_str(objective_data, "direction", "task.objective") if direction not in {"maximize", "minimize"}: raise TaskValidationError("task.objective.direction must be maximize or minimize") @@ -123,7 +120,7 @@ def load_task(task_path: Path) -> TaskSpec: direction=direction, ) - constraints_data = _require_list(_require_value(root, "constraints", "task"), "task.constraints") + constraints_data = _require_list(_require_value(root, "constraints"), "task.constraints") constraints = [] for index, item in enumerate(constraints_data): constraint_data = _require_mapping(item, f"task.constraints[{index}]") @@ -131,24 +128,24 @@ def load_task(task_path: Path) -> TaskSpec: ConstraintSpec( metric=_require_str(constraint_data, "metric", f"task.constraints[{index}]"), op=_require_str(constraint_data, "op", f"task.constraints[{index}]"), - value=_require_value(constraint_data, "value", f"task.constraints[{index}]"), + value=_require_value(constraint_data, "value"), ) ) - policy_data = _require_mapping(_require_value(root, "policy", "task"), "task.policy") + policy_data = _require_mapping(_require_value(root, "policy"), "task.policy") policy = PolicySpec( keep_if=_require_str(policy_data, "keep_if", "task.policy"), tie_breakers=_require_tie_breakers(policy_data, "tie_breakers", "task.policy"), on_failure=_require_str(policy_data, "on_failure", "task.policy"), ) - budget_data = _require_mapping(_require_value(root, "budget", "task"), "task.budget") + budget_data = _require_mapping(_require_value(root, "budget"), "task.budget") budget = BudgetSpec( max_iterations=_require_int(budget_data, "max_iterations", "task.budget"), max_failures=_require_int(budget_data, "max_failures", "task.budget"), ) - logging_data = _require_mapping(_require_value(root, "logging", "task"), "task.logging") + logging_data = _require_mapping(_require_value(root, "logging"), "task.logging") logging = LoggingSpec( results_file=_require_str(logging_data, "results_file", "task.logging"), candidate_dir=_require_str(logging_data, "candidate_dir", "task.logging"), @@ -167,4 +164,4 @@ def load_task(task_path: Path) -> TaskSpec: budget=budget, logging=logging, root_dir=task_path.parent, - ) + ) \ No newline at end of file