from __future__ import annotations from pathlib import Path import yaml from engine.models import ( ArtifactSpec, BudgetSpec, ConstraintSpec, LoggingSpec, MutationSpec, ObjectiveSpec, PolicySpec, RunnerSpec, ScorerParseSpec, ScorerSpec, TaskSpec, ) class TaskValidationError(ValueError): pass def _require_mapping(value: object, path: str) -> dict[str, object]: if not isinstance(value, dict): raise TaskValidationError(f"{path} must be a mapping") return value def _require_list(value: object, path: str) -> list[object]: if not isinstance(value, list): raise TaskValidationError(f"{path} must be a list") return value def _require_value(mapping: dict[str, object], key: str, path: str) -> object: if key not in mapping: raise TaskValidationError(f"Missing required field: {path}.{key}") return mapping[key] def _require_str(mapping: dict[str, object], key: str, path: str) -> str: value = _require_value(mapping, key, path) if not isinstance(value, str): raise TaskValidationError(f"{path}.{key} must be a string") return value def _require_int(mapping: dict[str, object], key: str, path: str) -> int: value = _require_value(mapping, key, path) if not isinstance(value, int) or isinstance(value, bool): raise TaskValidationError(f"{path}.{key} must be an integer") return value def _require_str_list(mapping: dict[str, object], key: str, path: str) -> list[str]: items = _require_list(_require_value(mapping, key, path), f"{path}.{key}") result: list[str] = [] for index, item in enumerate(items): if not isinstance(item, str): raise TaskValidationError(f"{path}.{key}[{index}] must be a string") result.append(item) return result def _require_tie_breakers(mapping: dict[str, object], key: str, path: str) -> list[dict[str, str]]: items = _require_list(_require_value(mapping, key, path), f"{path}.{key}") result: list[dict[str, str]] = [] for index, item in enumerate(items): entry = _require_mapping(item, f"{path}.{key}[{index}]") result.append({str(k): str(v) for k, v in entry.items()}) return result def load_task(task_path: Path) -> TaskSpec: task_data = yaml.safe_load(task_path.read_text(encoding="utf-8")) root = _require_mapping(task_data, "task") task_id = _require_str(root, "id", "task") description = _require_str(root, "description", "task") artifacts_data = _require_mapping(_require_value(root, "artifacts", "task"), "task.artifacts") artifacts = ArtifactSpec( include=_require_str_list(artifacts_data, "include", "task.artifacts"), exclude=_require_str_list(artifacts_data, "exclude", "task.artifacts"), max_files_per_iteration=_require_int(artifacts_data, "max_files_per_iteration", "task.artifacts"), ) mutation_data = _require_mapping(_require_value(root, "mutation", "task"), "task.mutation") mutation = MutationSpec( mode=_require_str(mutation_data, "mode", "task.mutation"), allowed_file_types=_require_str_list(mutation_data, "allowed_file_types", "task.mutation"), max_changed_lines=_require_int(mutation_data, "max_changed_lines", "task.mutation"), ) runner_data = _require_mapping(_require_value(root, "runner", "task"), "task.runner") runner = RunnerSpec( command=_require_str(runner_data, "command", "task.runner"), cwd=_require_str(runner_data, "cwd", "task.runner"), timeout_seconds=_require_int(runner_data, "timeout_seconds", "task.runner"), ) scorer_data = _require_mapping(_require_value(root, "scorer", "task"), "task.scorer") scorer_parse_data = _require_mapping(_require_value(scorer_data, "parse", "task.scorer"), "task.scorer.parse") scorer = ScorerSpec( type=_require_str(scorer_data, "type", "task.scorer"), command=_require_str(scorer_data, "command", "task.scorer"), parse=ScorerParseSpec( format=_require_str(scorer_parse_data, "format", "task.scorer.parse"), score_field=_require_str(scorer_parse_data, "score_field", "task.scorer.parse"), metrics_field=_require_str(scorer_parse_data, "metrics_field", "task.scorer.parse"), ), ) objective_data = _require_mapping(_require_value(root, "objective", "task"), "task.objective") direction = _require_str(objective_data, "direction", "task.objective") if direction not in {"maximize", "minimize"}: raise TaskValidationError("task.objective.direction must be maximize or minimize") objective = ObjectiveSpec( primary_metric=_require_str(objective_data, "primary_metric", "task.objective"), direction=direction, ) constraints_data = _require_list(_require_value(root, "constraints", "task"), "task.constraints") constraints = [] for index, item in enumerate(constraints_data): constraint_data = _require_mapping(item, f"task.constraints[{index}]") constraints.append( ConstraintSpec( metric=_require_str(constraint_data, "metric", f"task.constraints[{index}]"), op=_require_str(constraint_data, "op", f"task.constraints[{index}]"), value=_require_value(constraint_data, "value", f"task.constraints[{index}]"), ) ) policy_data = _require_mapping(_require_value(root, "policy", "task"), "task.policy") policy = PolicySpec( keep_if=_require_str(policy_data, "keep_if", "task.policy"), tie_breakers=_require_tie_breakers(policy_data, "tie_breakers", "task.policy"), on_failure=_require_str(policy_data, "on_failure", "task.policy"), ) budget_data = _require_mapping(_require_value(root, "budget", "task"), "task.budget") budget = BudgetSpec( max_iterations=_require_int(budget_data, "max_iterations", "task.budget"), max_failures=_require_int(budget_data, "max_failures", "task.budget"), ) logging_data = _require_mapping(_require_value(root, "logging", "task"), "task.logging") logging = LoggingSpec( results_file=_require_str(logging_data, "results_file", "task.logging"), candidate_dir=_require_str(logging_data, "candidate_dir", "task.logging"), ) return TaskSpec( id=task_id, description=description, artifacts=artifacts, mutation=mutation, runner=runner, scorer=scorer, objective=objective, constraints=constraints, policy=policy, budget=budget, logging=logging, root_dir=task_path.parent, )