feat: add yaml task loader
This commit is contained in:
parent
61b635a3e8
commit
79f1e88ba0
170
engine/task_loader.py
Normal file
170
engine/task_loader.py
Normal file
@ -0,0 +1,170 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from engine.models import (
|
||||
ArtifactSpec,
|
||||
BudgetSpec,
|
||||
ConstraintSpec,
|
||||
LoggingSpec,
|
||||
MutationSpec,
|
||||
ObjectiveSpec,
|
||||
PolicySpec,
|
||||
RunnerSpec,
|
||||
ScorerParseSpec,
|
||||
ScorerSpec,
|
||||
TaskSpec,
|
||||
)
|
||||
|
||||
|
||||
class TaskValidationError(ValueError):
    """Raised when a task YAML document is missing a required field or a
    field has the wrong type.

    Subclasses ValueError so callers that catch ValueError keep working.
    """
|
||||
|
||||
|
||||
def _require_mapping(value: object, path: str) -> dict[str, object]:
|
||||
if not isinstance(value, dict):
|
||||
raise TaskValidationError(f"{path} must be a mapping")
|
||||
return value
|
||||
|
||||
|
||||
def _require_list(value: object, path: str) -> list[object]:
|
||||
if not isinstance(value, list):
|
||||
raise TaskValidationError(f"{path} must be a list")
|
||||
return value
|
||||
|
||||
|
||||
def _require_value(mapping: dict[str, object], key: str, path: str) -> object:
|
||||
if key not in mapping:
|
||||
raise TaskValidationError(f"Missing required field: {path}.{key}")
|
||||
return mapping[key]
|
||||
|
||||
|
||||
def _require_str(mapping: dict[str, object], key: str, path: str) -> str:
    """Fetch *key* from *mapping* and ensure it holds a string."""
    candidate = _require_value(mapping, key, path)
    if isinstance(candidate, str):
        return candidate
    raise TaskValidationError(f"{path}.{key} must be a string")
|
||||
|
||||
|
||||
def _require_int(mapping: dict[str, object], key: str, path: str) -> int:
    """Fetch *key* from *mapping* and ensure it holds an integer.

    bool is a subclass of int in Python, so it is ruled out explicitly.
    """
    candidate = _require_value(mapping, key, path)
    if isinstance(candidate, bool) or not isinstance(candidate, int):
        raise TaskValidationError(f"{path}.{key} must be an integer")
    return candidate
|
||||
|
||||
|
||||
def _require_str_list(mapping: dict[str, object], key: str, path: str) -> list[str]:
    """Fetch *key* from *mapping* and ensure it is a list of strings.

    Returns a fresh list; the error message pinpoints the offending index.
    """
    raw = _require_list(_require_value(mapping, key, path), f"{path}.{key}")
    for position, entry in enumerate(raw):
        if not isinstance(entry, str):
            raise TaskValidationError(f"{path}.{key}[{position}] must be a string")
    return list(raw)
|
||||
|
||||
|
||||
def _require_tie_breakers(mapping: dict[str, object], key: str, path: str) -> list[dict[str, str]]:
    """Fetch *key* from *mapping* as a list of mappings, coercing every
    key and value in each entry to str.
    """
    raw = _require_list(_require_value(mapping, key, path), f"{path}.{key}")
    normalized: list[dict[str, str]] = []
    for position, entry in enumerate(raw):
        checked = _require_mapping(entry, f"{path}.{key}[{position}]")
        normalized.append({str(field): str(val) for field, val in checked.items()})
    return normalized
|
||||
|
||||
|
||||
def _load_artifacts(root: dict[str, object]) -> ArtifactSpec:
    """Build the artifact selection config from task.artifacts."""
    data = _require_mapping(_require_value(root, "artifacts", "task"), "task.artifacts")
    return ArtifactSpec(
        include=_require_str_list(data, "include", "task.artifacts"),
        exclude=_require_str_list(data, "exclude", "task.artifacts"),
        max_files_per_iteration=_require_int(data, "max_files_per_iteration", "task.artifacts"),
    )


def _load_mutation(root: dict[str, object]) -> MutationSpec:
    """Build the mutation limits config from task.mutation."""
    data = _require_mapping(_require_value(root, "mutation", "task"), "task.mutation")
    return MutationSpec(
        mode=_require_str(data, "mode", "task.mutation"),
        allowed_file_types=_require_str_list(data, "allowed_file_types", "task.mutation"),
        max_changed_lines=_require_int(data, "max_changed_lines", "task.mutation"),
    )


def _load_runner(root: dict[str, object]) -> RunnerSpec:
    """Build the runner command config from task.runner."""
    data = _require_mapping(_require_value(root, "runner", "task"), "task.runner")
    return RunnerSpec(
        command=_require_str(data, "command", "task.runner"),
        cwd=_require_str(data, "cwd", "task.runner"),
        timeout_seconds=_require_int(data, "timeout_seconds", "task.runner"),
    )


def _load_scorer(root: dict[str, object]) -> ScorerSpec:
    """Build the scorer command and its output-parsing config from task.scorer."""
    data = _require_mapping(_require_value(root, "scorer", "task"), "task.scorer")
    parse_data = _require_mapping(_require_value(data, "parse", "task.scorer"), "task.scorer.parse")
    return ScorerSpec(
        type=_require_str(data, "type", "task.scorer"),
        command=_require_str(data, "command", "task.scorer"),
        parse=ScorerParseSpec(
            format=_require_str(parse_data, "format", "task.scorer.parse"),
            score_field=_require_str(parse_data, "score_field", "task.scorer.parse"),
            metrics_field=_require_str(parse_data, "metrics_field", "task.scorer.parse"),
        ),
    )


def _load_objective(root: dict[str, object]) -> ObjectiveSpec:
    """Build the optimization objective; direction is restricted to
    maximize/minimize."""
    data = _require_mapping(_require_value(root, "objective", "task"), "task.objective")
    direction = _require_str(data, "direction", "task.objective")
    if direction not in {"maximize", "minimize"}:
        raise TaskValidationError("task.objective.direction must be maximize or minimize")
    return ObjectiveSpec(
        primary_metric=_require_str(data, "primary_metric", "task.objective"),
        direction=direction,
    )


def _load_constraints(root: dict[str, object]) -> list[ConstraintSpec]:
    """Build the list of metric constraints from task.constraints."""
    items = _require_list(_require_value(root, "constraints", "task"), "task.constraints")
    constraints: list[ConstraintSpec] = []
    for index, item in enumerate(items):
        location = f"task.constraints[{index}]"
        entry = _require_mapping(item, location)
        constraints.append(
            ConstraintSpec(
                metric=_require_str(entry, "metric", location),
                op=_require_str(entry, "op", location),
                # value is intentionally untyped: constraints may compare
                # against ints, floats, or strings.
                value=_require_value(entry, "value", location),
            )
        )
    return constraints


def _load_policy(root: dict[str, object]) -> PolicySpec:
    """Build the candidate acceptance policy from task.policy."""
    data = _require_mapping(_require_value(root, "policy", "task"), "task.policy")
    return PolicySpec(
        keep_if=_require_str(data, "keep_if", "task.policy"),
        tie_breakers=_require_tie_breakers(data, "tie_breakers", "task.policy"),
        on_failure=_require_str(data, "on_failure", "task.policy"),
    )


def _load_budget(root: dict[str, object]) -> BudgetSpec:
    """Build the iteration/failure budget from task.budget."""
    data = _require_mapping(_require_value(root, "budget", "task"), "task.budget")
    return BudgetSpec(
        max_iterations=_require_int(data, "max_iterations", "task.budget"),
        max_failures=_require_int(data, "max_failures", "task.budget"),
    )


def _load_logging(root: dict[str, object]) -> LoggingSpec:
    """Build the results/candidate output locations from task.logging."""
    data = _require_mapping(_require_value(root, "logging", "task"), "task.logging")
    return LoggingSpec(
        results_file=_require_str(data, "results_file", "task.logging"),
        candidate_dir=_require_str(data, "candidate_dir", "task.logging"),
    )


def load_task(task_path: Path) -> TaskSpec:
    """Load and validate a task definition from a YAML file.

    Args:
        task_path: Path to the task YAML document.

    Returns:
        A fully validated TaskSpec whose root_dir is the directory
        containing the YAML file.

    Raises:
        TaskValidationError: If the document is not a mapping, or any
            required section or field is missing or mistyped.
    """
    task_data = yaml.safe_load(task_path.read_text(encoding="utf-8"))
    root = _require_mapping(task_data, "task")
    # Keyword arguments are evaluated left-to-right, so validation errors
    # surface in the same section order as the YAML schema.
    return TaskSpec(
        id=_require_str(root, "id", "task"),
        description=_require_str(root, "description", "task"),
        artifacts=_load_artifacts(root),
        mutation=_load_mutation(root),
        runner=_load_runner(root),
        scorer=_load_scorer(root),
        objective=_load_objective(root),
        constraints=_load_constraints(root),
        policy=_load_policy(root),
        budget=_load_budget(root),
        logging=_load_logging(root),
        root_dir=task_path.parent,
    )
|
||||
@ -2,12 +2,10 @@ from pathlib import Path
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from engine.task_loader import load_task
|
||||
from engine.task_loader import TaskValidationError, load_task
|
||||
|
||||
|
||||
class TaskLoaderSmokeTest(unittest.TestCase):
|
||||
def test_loads_minimal_task(self) -> None:
|
||||
task_yaml = """
|
||||
VALID_TASK = """
|
||||
id: demo
|
||||
description: Demo task
|
||||
artifacts:
|
||||
@ -48,12 +46,33 @@ logging:
|
||||
results_file: work/results.jsonl
|
||||
candidate_dir: work/candidates
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
task_path = Path(tmp) / "task.yaml"
|
||||
task_path.write_text(task_yaml, encoding="utf-8")
|
||||
task = load_task(task_path)
|
||||
|
||||
|
||||
class TaskLoaderTest(unittest.TestCase):
    """Tests for engine.task_loader.load_task: one happy-path load plus
    two validation failures derived from the shared VALID_TASK fixture."""

    def write_task(self, content: str) -> Path:
        """Write *content* to a task.yaml in a fresh temp dir and return
        its path; directory cleanup is registered with addCleanup."""
        temp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(temp_dir.cleanup)
        task_path = Path(temp_dir.name) / "task.yaml"
        task_path.write_text(content, encoding="utf-8")
        return task_path

    def test_loads_minimal_task(self) -> None:
        """The valid fixture parses into a TaskSpec with expected fields."""
        task = load_task(self.write_task(VALID_TASK))
        self.assertEqual(task.id, "demo")
        self.assertEqual(task.objective.direction, "maximize")
        self.assertEqual(task.artifacts.max_files_per_iteration, 1)
        self.assertEqual(task.constraints[0].metric, "violation_count")

    def test_rejects_missing_required_section(self) -> None:
        """Removing the objective section raises TaskValidationError
        whose message names the missing section."""
        # NOTE(review): the replace needle must match VALID_TASK's YAML
        # indentation byte-for-byte — confirm against the fixture.
        content = VALID_TASK.replace("objective:\n primary_metric: score\n direction: maximize\n", "")
        with self.assertRaises(TaskValidationError) as ctx:
            load_task(self.write_task(content))
        self.assertIn("objective", str(ctx.exception))

    def test_rejects_invalid_direction(self) -> None:
        """A direction outside maximize/minimize is rejected."""
        content = VALID_TASK.replace("direction: maximize", "direction: sideways")
        with self.assertRaises(TaskValidationError) as ctx:
            load_task(self.write_task(content))
        self.assertIn("direction", str(ctx.exception))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Loading…
Reference in New Issue
Block a user