diff --git a/engine/task_loader.py b/engine/task_loader.py
new file mode 100644
index 0000000..ecec106
--- /dev/null
+++ b/engine/task_loader.py
@@ -0,0 +1,199 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+import yaml
+
+from engine.models import (
+    ArtifactSpec,
+    BudgetSpec,
+    ConstraintSpec,
+    LoggingSpec,
+    MutationSpec,
+    ObjectiveSpec,
+    PolicySpec,
+    RunnerSpec,
+    ScorerParseSpec,
+    ScorerSpec,
+    TaskSpec,
+)
+
+
+class TaskValidationError(ValueError):
+    """Raised when a task definition fails validation."""
+
+
+def _require_mapping(value: object, path: str) -> dict[str, object]:
+    """Return ``value`` if it is a dict; otherwise raise an error naming ``path``."""
+    if not isinstance(value, dict):
+        raise TaskValidationError(f"{path} must be a mapping")
+    return value
+
+
+def _require_list(value: object, path: str) -> list[object]:
+    """Return ``value`` if it is a list; otherwise raise an error naming ``path``."""
+    if not isinstance(value, list):
+        raise TaskValidationError(f"{path} must be a list")
+    return value
+
+
+def _require_value(mapping: dict[str, object], key: str, path: str) -> object:
+    """Return ``mapping[key]``, raising if the key is absent (any value is accepted)."""
+    if key not in mapping:
+        raise TaskValidationError(f"Missing required field: {path}.{key}")
+    return mapping[key]
+
+
+def _require_section(mapping: dict[str, object], key: str, path: str) -> dict[str, object]:
+    """Return the nested mapping at ``mapping[key]``, checking presence then type."""
+    return _require_mapping(_require_value(mapping, key, path), f"{path}.{key}")
+
+
+def _require_str(mapping: dict[str, object], key: str, path: str) -> str:
+    """Return the required string at ``mapping[key]``."""
+    value = _require_value(mapping, key, path)
+    if not isinstance(value, str):
+        raise TaskValidationError(f"{path}.{key} must be a string")
+    return value
+
+
+def _require_int(mapping: dict[str, object], key: str, path: str) -> int:
+    """Return the required integer at ``mapping[key]``."""
+    value = _require_value(mapping, key, path)
+    # bool is a subclass of int, so YAML true/false must be rejected explicitly.
+    if not isinstance(value, int) or isinstance(value, bool):
+        raise TaskValidationError(f"{path}.{key} must be an integer")
+    return value
+
+
+def _require_str_list(mapping: dict[str, object], key: str, path: str) -> list[str]:
+    """Return the required list at ``mapping[key]``, checking every item is a string."""
+    items = _require_list(_require_value(mapping, key, path), f"{path}.{key}")
+    result: list[str] = []
+    for index, item in enumerate(items):
+        if not isinstance(item, str):
+            raise TaskValidationError(f"{path}.{key}[{index}] must be a string")
+        result.append(item)
+    return result
+
+
+def _require_tie_breakers(mapping: dict[str, object], key: str, path: str) -> list[dict[str, str]]:
+    """Return the required list of mappings at ``mapping[key]``.
+
+    Keys and values of each entry are converted with ``str()`` rather than
+    type-checked.
+    """
+    items = _require_list(_require_value(mapping, key, path), f"{path}.{key}")
+    result: list[dict[str, str]] = []
+    for index, item in enumerate(items):
+        entry = _require_mapping(item, f"{path}.{key}[{index}]")
+        result.append({str(k): str(v) for k, v in entry.items()})
+    return result
+
+
+def load_task(task_path: Path) -> TaskSpec:
+    """Load the YAML file at ``task_path`` and build a validated ``TaskSpec``.
+
+    Validation stops at the first problem found. Errors from reading the file
+    or parsing the YAML are not caught here. The returned spec's ``root_dir``
+    is ``task_path.parent``.
+
+    Raises:
+        TaskValidationError: If a required field is missing or has the wrong
+            type, or ``objective.direction`` is not ``maximize``/``minimize``.
+    """
+    task_data = yaml.safe_load(task_path.read_text(encoding="utf-8"))
+    root = _require_mapping(task_data, "task")
+
+    task_id = _require_str(root, "id", "task")
+    description = _require_str(root, "description", "task")
+
+    artifacts_data = _require_section(root, "artifacts", "task")
+    artifacts = ArtifactSpec(
+        include=_require_str_list(artifacts_data, "include", "task.artifacts"),
+        exclude=_require_str_list(artifacts_data, "exclude", "task.artifacts"),
+        max_files_per_iteration=_require_int(artifacts_data, "max_files_per_iteration", "task.artifacts"),
+    )
+
+    mutation_data = _require_section(root, "mutation", "task")
+    mutation = MutationSpec(
+        mode=_require_str(mutation_data, "mode", "task.mutation"),
+        allowed_file_types=_require_str_list(mutation_data, "allowed_file_types", "task.mutation"),
+        max_changed_lines=_require_int(mutation_data, "max_changed_lines", "task.mutation"),
+    )
+
+    runner_data = _require_section(root, "runner", "task")
+    runner = RunnerSpec(
+        command=_require_str(runner_data, "command", "task.runner"),
+        cwd=_require_str(runner_data, "cwd", "task.runner"),
+        timeout_seconds=_require_int(runner_data, "timeout_seconds", "task.runner"),
+    )
+
+    scorer_data = _require_section(root, "scorer", "task")
+    scorer_parse_data = _require_section(scorer_data, "parse", "task.scorer")
+    scorer = ScorerSpec(
+        type=_require_str(scorer_data, "type", "task.scorer"),
+        command=_require_str(scorer_data, "command", "task.scorer"),
+        parse=ScorerParseSpec(
+            format=_require_str(scorer_parse_data, "format", "task.scorer.parse"),
+            score_field=_require_str(scorer_parse_data, "score_field", "task.scorer.parse"),
+            metrics_field=_require_str(scorer_parse_data, "metrics_field", "task.scorer.parse"),
+        ),
+    )
+
+    objective_data = _require_section(root, "objective", "task")
+    direction = _require_str(objective_data, "direction", "task.objective")
+    if direction not in {"maximize", "minimize"}:
+        raise TaskValidationError("task.objective.direction must be maximize or minimize")
+    objective = ObjectiveSpec(
+        primary_metric=_require_str(objective_data, "primary_metric", "task.objective"),
+        direction=direction,
+    )
+
+    constraints_data = _require_list(_require_value(root, "constraints", "task"), "task.constraints")
+    constraints = []
+    for index, item in enumerate(constraints_data):
+        constraint_path = f"task.constraints[{index}]"
+        constraint_data = _require_mapping(item, constraint_path)
+        constraints.append(
+            ConstraintSpec(
+                metric=_require_str(constraint_data, "metric", constraint_path),
+                op=_require_str(constraint_data, "op", constraint_path),
+                # Only presence is checked; any YAML value is accepted.
+                value=_require_value(constraint_data, "value", constraint_path),
+            )
+        )
+
+    policy_data = _require_section(root, "policy", "task")
+    policy = PolicySpec(
+        keep_if=_require_str(policy_data, "keep_if", "task.policy"),
+        tie_breakers=_require_tie_breakers(policy_data, "tie_breakers", "task.policy"),
+        on_failure=_require_str(policy_data, "on_failure", "task.policy"),
+    )
+
+    budget_data = _require_section(root, "budget", "task")
+    budget = BudgetSpec(
+        max_iterations=_require_int(budget_data, "max_iterations", "task.budget"),
+        max_failures=_require_int(budget_data, "max_failures", "task.budget"),
+    )
+
+    logging_data = _require_section(root, "logging", "task")
+    logging = LoggingSpec(
+        results_file=_require_str(logging_data, "results_file", "task.logging"),
+        candidate_dir=_require_str(logging_data, "candidate_dir", "task.logging"),
+    )
+
+    return TaskSpec(
+        id=task_id,
+        description=description,
+        artifacts=artifacts,
+        mutation=mutation,
+        runner=runner,
+        scorer=scorer,
+        objective=objective,
+        constraints=constraints,
+        policy=policy,
+        budget=budget,
+        logging=logging,
+        root_dir=task_path.parent,
+    )
diff --git a/tests/test_task_loader.py b/tests/test_task_loader.py
index 1aa6756..0bdf1e6 100644
--- a/tests/test_task_loader.py
+++ b/tests/test_task_loader.py
@@ -2,12 +2,10 @@ from pathlib import Path
 import tempfile
 import unittest
 
-from engine.task_loader import load_task
+from engine.task_loader import TaskValidationError, load_task
 
 
-class TaskLoaderSmokeTest(unittest.TestCase):
-    def test_loads_minimal_task(self) -> None:
-        task_yaml = """
+VALID_TASK = """
 id: demo
 description: Demo task
 artifacts:
@@ -48,12 +46,33 @@ logging:
   results_file: work/results.jsonl
   candidate_dir: work/candidates
 """
-        with tempfile.TemporaryDirectory() as tmp:
-            task_path = Path(tmp) / "task.yaml"
-            task_path.write_text(task_yaml, encoding="utf-8")
-            task = load_task(task_path)
+
+
+class TaskLoaderTest(unittest.TestCase):
+    def write_task(self, content: str) -> Path:
+        temp_dir = tempfile.TemporaryDirectory()
+        self.addCleanup(temp_dir.cleanup)
+        task_path = Path(temp_dir.name) / "task.yaml"
+        task_path.write_text(content, encoding="utf-8")
+        return task_path
+
+    def test_loads_minimal_task(self) -> None:
+        task = load_task(self.write_task(VALID_TASK))
         self.assertEqual(task.id, "demo")
-        self.assertEqual(task.objective.direction, "maximize")
+        self.assertEqual(task.artifacts.max_files_per_iteration, 1)
+        self.assertEqual(task.constraints[0].metric, "violation_count")
+
+    def test_rejects_missing_required_section(self) -> None:
+        content = VALID_TASK.replace("objective:\n  primary_metric: score\n  direction: maximize\n", "")
+        with self.assertRaises(TaskValidationError) as ctx:
+            load_task(self.write_task(content))
+        self.assertIn("objective", str(ctx.exception))
+
+    def test_rejects_invalid_direction(self) -> None:
+        content = VALID_TASK.replace("direction: maximize", "direction: sideways")
+        with self.assertRaises(TaskValidationError) as ctx:
+            load_task(self.write_task(content))
+        self.assertIn("direction", str(ctx.exception))
 
 
 if __name__ == "__main__":