from __future__ import annotations from pathlib import Path from typing import Any import yaml from engine.models import ( ArtifactSpec, BudgetSpec, ConstraintSpec, LoggingSpec, MutationSpec, ObjectiveSpec, PolicySpec, RunnerSpec, ScorerParseSpec, ScorerSpec, TaskSpec, ) class TaskValidationError(ValueError): pass def _require_mapping(value: Any, path: str) -> dict[str, Any]: if not isinstance(value, dict): raise TaskValidationError(f"{path} must be a mapping") return value def _require_list(value: Any, path: str) -> list[Any]: if not isinstance(value, list): raise TaskValidationError(f"{path} must be a list") return value def _require_value(mapping: dict[str, Any], key: str) -> Any: if key not in mapping: raise TaskValidationError(f"missing required field: {key}") return mapping[key] def load_task(task_path: Path) -> TaskSpec: def _require_str(mapping: dict[str, Any], key: str, path: str) -> str: value = _require_value(mapping, key) if not isinstance(value, str): raise TaskValidationError(f"{path}.{key} must be a string") return value def _require_int(mapping: dict[str, Any], key: str, path: str) -> int: value = _require_value(mapping, key) if not isinstance(value, int) or isinstance(value, bool): raise TaskValidationError(f"{path}.{key} must be an integer") return value def _require_str_list(mapping: dict[str, Any], key: str, path: str) -> list[str]: items = _require_list(_require_value(mapping, key), f"{path}.{key}") result: list[str] = [] for index, item in enumerate(items): if not isinstance(item, str): raise TaskValidationError(f"{path}.{key}[{index}] must be a string") result.append(item) return result def _require_tie_breakers(mapping: dict[str, Any], key: str, path: str) -> list[dict[str, str]]: items = _require_list(_require_value(mapping, key), f"{path}.{key}") result: list[dict[str, str]] = [] for index, item in enumerate(items): entry = _require_mapping(item, f"{path}.{key}[{index}]") result.append({str(k): str(v) for k, v in entry.items()}) return result task_data = yaml.safe_load(task_path.read_text(encoding="utf-8")) root = _require_mapping(task_data, "task") task_id = _require_str(root, "id", "task") description = _require_str(root, "description", "task") artifacts_data = _require_mapping(_require_value(root, "artifacts"), "task.artifacts") artifacts = ArtifactSpec( include=_require_str_list(artifacts_data, "include", "task.artifacts"), exclude=_require_str_list(artifacts_data, "exclude", "task.artifacts"), max_files_per_iteration=_require_int(artifacts_data, "max_files_per_iteration", "task.artifacts"), ) mutation_data = _require_mapping(_require_value(root, "mutation"), "task.mutation") mutation = MutationSpec( mode=_require_str(mutation_data, "mode", "task.mutation"), allowed_file_types=_require_str_list(mutation_data, "allowed_file_types", "task.mutation"), max_changed_lines=_require_int(mutation_data, "max_changed_lines", "task.mutation"), ) runner_data = _require_mapping(_require_value(root, "runner"), "task.runner") runner = RunnerSpec( command=_require_str(runner_data, "command", "task.runner"), cwd=_require_str(runner_data, "cwd", "task.runner"), timeout_seconds=_require_int(runner_data, "timeout_seconds", "task.runner"), ) scorer_data = _require_mapping(_require_value(root, "scorer"), "task.scorer") scorer_parse_data = _require_mapping(_require_value(scorer_data, "parse"), "task.scorer.parse") scorer = ScorerSpec( type=_require_str(scorer_data, "type", "task.scorer"), command=_require_str(scorer_data, "command", "task.scorer"), parse=ScorerParseSpec( format=_require_str(scorer_parse_data, "format", "task.scorer.parse"), score_field=_require_str(scorer_parse_data, "score_field", "task.scorer.parse"), metrics_field=_require_str(scorer_parse_data, "metrics_field", "task.scorer.parse"), ), ) objective_data = _require_mapping(_require_value(root, "objective"), "task.objective") direction = _require_str(objective_data, "direction", "task.objective") if direction not in {"maximize", "minimize"}: raise TaskValidationError("task.objective.direction must be maximize or minimize") objective = ObjectiveSpec( primary_metric=_require_str(objective_data, "primary_metric", "task.objective"), direction=direction, ) constraints_data = _require_list(_require_value(root, "constraints"), "task.constraints") constraints = [] for index, item in enumerate(constraints_data): constraint_data = _require_mapping(item, f"task.constraints[{index}]") constraints.append( ConstraintSpec( metric=_require_str(constraint_data, "metric", f"task.constraints[{index}]"), op=_require_str(constraint_data, "op", f"task.constraints[{index}]"), value=_require_value(constraint_data, "value"), ) ) policy_data = _require_mapping(_require_value(root, "policy"), "task.policy") policy = PolicySpec( keep_if=_require_str(policy_data, "keep_if", "task.policy"), tie_breakers=_require_tie_breakers(policy_data, "tie_breakers", "task.policy"), on_failure=_require_str(policy_data, "on_failure", "task.policy"), ) budget_data = _require_mapping(_require_value(root, "budget"), "task.budget") budget = BudgetSpec( max_iterations=_require_int(budget_data, "max_iterations", "task.budget"), max_failures=_require_int(budget_data, "max_failures", "task.budget"), ) logging_data = _require_mapping(_require_value(root, "logging"), "task.logging") logging = LoggingSpec( results_file=_require_str(logging_data, "results_file", "task.logging"), candidate_dir=_require_str(logging_data, "candidate_dir", "task.logging"), ) return TaskSpec( id=task_id, description=description, artifacts=artifacts, mutation=mutation, runner=runner, scorer=scorer, objective=objective, constraints=constraints, policy=policy, budget=budget, logging=logging, root_dir=task_path.parent, )