189 lines
7.5 KiB
Python
189 lines
7.5 KiB
Python
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import yaml
|
|
|
|
from engine.models import (
|
|
ArtifactSpec,
|
|
BudgetSpec,
|
|
ConstraintSpec,
|
|
LoggingSpec,
|
|
MutationSpec,
|
|
ObjectiveSpec,
|
|
PolicySpec,
|
|
RunnerSpec,
|
|
ScorerParseSpec,
|
|
ScorerSpec,
|
|
TaskSpec,
|
|
)
|
|
|
|
|
|
class TaskValidationError(ValueError):
    """Raised when a task definition file is malformed or fails validation.

    Derives from ``ValueError``, so existing ``except ValueError`` handlers
    also catch task validation failures.
    """
|
|
|
|
|
|
def _require_mapping(value: Any, path: str) -> dict[str, Any]:
|
|
if not isinstance(value, dict):
|
|
raise TaskValidationError(f"{path} must be a mapping")
|
|
return value
|
|
|
|
|
|
def _require_list(value: Any, path: str) -> list[Any]:
|
|
if not isinstance(value, list):
|
|
raise TaskValidationError(f"{path} must be a list")
|
|
return value
|
|
|
|
|
|
def _require_value(mapping: dict[str, Any], key: str) -> Any:
|
|
if key not in mapping:
|
|
raise TaskValidationError(f"missing required field: {key}")
|
|
return mapping[key]
|
|
|
|
|
|
def _require_field(mapping: dict[str, Any], key: str, path: str) -> Any:
    """Return ``mapping[key]``, naming the fully-qualified field when it is absent.

    The error reports ``<path>.<key>`` rather than the bare key, so keys that
    occur in several sections (``command`` exists under both ``task.runner``
    and ``task.scorer``) produce distinguishable messages.

    Raises:
        TaskValidationError: If ``key`` is not present in ``mapping``.
    """
    if key not in mapping:
        raise TaskValidationError(f"missing required field: {path}.{key}")
    return mapping[key]


def _require_section(mapping: dict[str, Any], key: str, path: str) -> dict[str, Any]:
    """Return the required nested mapping ``mapping[key]``, located at ``<path>.<key>``."""
    return _require_mapping(_require_field(mapping, key, path), f"{path}.{key}")


def _require_str(mapping: dict[str, Any], key: str, path: str) -> str:
    """Return the required string field ``mapping[key]``.

    Raises:
        TaskValidationError: If the field is missing or not a ``str``.
    """
    value = _require_field(mapping, key, path)
    if not isinstance(value, str):
        raise TaskValidationError(f"{path}.{key} must be a string")
    return value


def _require_int(mapping: dict[str, Any], key: str, path: str) -> int:
    """Return the required integer field ``mapping[key]``.

    ``bool`` is rejected explicitly because it subclasses ``int``; without the
    check a YAML ``true``/``false`` would be accepted as 1/0.

    Raises:
        TaskValidationError: If the field is missing or not a non-bool ``int``.
    """
    value = _require_field(mapping, key, path)
    if not isinstance(value, int) or isinstance(value, bool):
        raise TaskValidationError(f"{path}.{key} must be an integer")
    return value


def _require_str_list(mapping: dict[str, Any], key: str, path: str) -> list[str]:
    """Return the required list-of-strings field ``mapping[key]`` as a new list.

    Raises:
        TaskValidationError: If the field is missing, not a list, or contains
            a non-string element (the first offending index is reported).
    """
    items = _require_list(_require_field(mapping, key, path), f"{path}.{key}")
    for index, item in enumerate(items):
        if not isinstance(item, str):
            raise TaskValidationError(f"{path}.{key}[{index}] must be a string")
    return list(items)


def _require_tie_breakers(mapping: dict[str, Any], key: str, path: str) -> list[dict[str, str]]:
    """Return the required list of tie-breaker mappings ``mapping[key]``.

    Each entry must be a mapping. Its keys and values are converted with
    ``str()`` rather than validated, so non-string YAML scalars (numbers,
    ``null``) are accepted and stringified.

    Raises:
        TaskValidationError: If the field is missing, not a list, or an entry
            is not a mapping.
    """
    items = _require_list(_require_field(mapping, key, path), f"{path}.{key}")
    result: list[dict[str, str]] = []
    for index, item in enumerate(items):
        entry = _require_mapping(item, f"{path}.{key}[{index}]")
        result.append({str(k): str(v) for k, v in entry.items()})
    return result


def _parse_artifacts(root: dict[str, Any]) -> ArtifactSpec:
    """Parse the ``task.artifacts`` section."""
    path = "task.artifacts"
    data = _require_section(root, "artifacts", "task")
    return ArtifactSpec(
        include=_require_str_list(data, "include", path),
        exclude=_require_str_list(data, "exclude", path),
        max_files_per_iteration=_require_int(data, "max_files_per_iteration", path),
    )


def _parse_mutation(root: dict[str, Any]) -> MutationSpec:
    """Parse ``task.mutation``; only the ``direct_edit`` mode is accepted."""
    path = "task.mutation"
    data = _require_section(root, "mutation", "task")
    mode = _require_str(data, "mode", path)
    if mode != "direct_edit":
        raise TaskValidationError("task.mutation.mode must be direct_edit")
    return MutationSpec(
        mode=mode,
        allowed_file_types=_require_str_list(data, "allowed_file_types", path),
        max_changed_lines=_require_int(data, "max_changed_lines", path),
    )


def _parse_runner(root: dict[str, Any]) -> RunnerSpec:
    """Parse the ``task.runner`` section."""
    path = "task.runner"
    data = _require_section(root, "runner", "task")
    return RunnerSpec(
        command=_require_str(data, "command", path),
        cwd=_require_str(data, "cwd", path),
        timeout_seconds=_require_int(data, "timeout_seconds", path),
    )


def _parse_scorer(root: dict[str, Any]) -> ScorerSpec:
    """Parse ``task.scorer``; only a ``command`` scorer with ``json`` parsing is accepted."""
    path = "task.scorer"
    parse_path = "task.scorer.parse"
    data = _require_section(root, "scorer", "task")
    parse_data = _require_section(data, "parse", path)
    scorer_type = _require_str(data, "type", path)
    if scorer_type != "command":
        raise TaskValidationError("task.scorer.type must be command")
    parse_format = _require_str(parse_data, "format", parse_path)
    if parse_format != "json":
        raise TaskValidationError("task.scorer.parse.format must be json")
    return ScorerSpec(
        type=scorer_type,
        command=_require_str(data, "command", path),
        parse=ScorerParseSpec(
            format=parse_format,
            score_field=_require_str(parse_data, "score_field", parse_path),
            metrics_field=_require_str(parse_data, "metrics_field", parse_path),
        ),
    )


def _parse_objective(root: dict[str, Any]) -> ObjectiveSpec:
    """Parse ``task.objective``; ``direction`` must be ``maximize`` or ``minimize``."""
    path = "task.objective"
    data = _require_section(root, "objective", "task")
    direction = _require_str(data, "direction", path)
    if direction not in {"maximize", "minimize"}:
        raise TaskValidationError("task.objective.direction must be maximize or minimize")
    return ObjectiveSpec(
        primary_metric=_require_str(data, "primary_metric", path),
        direction=direction,
    )


def _parse_constraints(root: dict[str, Any]) -> list[ConstraintSpec]:
    """Parse the ``task.constraints`` list.

    ``op`` must be one of ``<=``, ``>=``, ``==``. ``value`` is required but is
    passed through exactly as parsed from YAML, without any type check.
    """
    items = _require_list(_require_field(root, "constraints", "task"), "task.constraints")
    constraints: list[ConstraintSpec] = []
    for index, item in enumerate(items):
        path = f"task.constraints[{index}]"
        data = _require_mapping(item, path)
        op = _require_str(data, "op", path)
        if op not in {"<=", ">=", "=="}:
            raise TaskValidationError(f"{path}.op must be one of <=, >=, ==")
        constraints.append(
            ConstraintSpec(
                metric=_require_str(data, "metric", path),
                op=op,
                value=_require_field(data, "value", path),
            )
        )
    return constraints


def _parse_policy(root: dict[str, Any]) -> PolicySpec:
    """Parse ``task.policy``; only ``better_primary`` / ``discard`` are accepted."""
    path = "task.policy"
    data = _require_section(root, "policy", "task")
    keep_if = _require_str(data, "keep_if", path)
    if keep_if != "better_primary":
        raise TaskValidationError("task.policy.keep_if must be better_primary")
    on_failure = _require_str(data, "on_failure", path)
    if on_failure != "discard":
        raise TaskValidationError("task.policy.on_failure must be discard")
    return PolicySpec(
        keep_if=keep_if,
        tie_breakers=_require_tie_breakers(data, "tie_breakers", path),
        on_failure=on_failure,
    )


def _parse_budget(root: dict[str, Any]) -> BudgetSpec:
    """Parse the ``task.budget`` section."""
    path = "task.budget"
    data = _require_section(root, "budget", "task")
    return BudgetSpec(
        max_iterations=_require_int(data, "max_iterations", path),
        max_failures=_require_int(data, "max_failures", path),
    )


def _parse_logging(root: dict[str, Any]) -> LoggingSpec:
    """Parse the ``task.logging`` section."""
    path = "task.logging"
    data = _require_section(root, "logging", "task")
    return LoggingSpec(
        results_file=_require_str(data, "results_file", path),
        candidate_dir=_require_str(data, "candidate_dir", path),
    )


def load_task(task_path: Path) -> TaskSpec:
    """Load and validate a YAML task definition file.

    Args:
        task_path: Path of the task file. It is read as UTF-8, and its parent
            directory becomes ``TaskSpec.root_dir``.

    Returns:
        The fully validated ``TaskSpec``.

    Raises:
        TaskValidationError: If the YAML is malformed, the document is not a
            mapping, a required field is missing (reported as its full dotted
            path), or a field has the wrong type or an unsupported value.
        OSError: If the file cannot be read (propagated unwrapped).
        UnicodeDecodeError: If the file is not valid UTF-8 (propagated unwrapped).
    """
    text = task_path.read_text(encoding="utf-8")
    try:
        task_data = yaml.safe_load(text)
    except yaml.YAMLError as exc:
        raise TaskValidationError(str(exc)) from exc
    root = _require_mapping(task_data, "task")

    # Sections are validated in a fixed order; the first invalid field
    # encountered determines which error is raised.
    task_id = _require_str(root, "id", "task")
    description = _require_str(root, "description", "task")
    artifacts = _parse_artifacts(root)
    mutation = _parse_mutation(root)
    runner = _parse_runner(root)
    scorer = _parse_scorer(root)
    objective = _parse_objective(root)
    constraints = _parse_constraints(root)
    policy = _parse_policy(root)
    budget = _parse_budget(root)
    # Named logging_spec to avoid shadowing the stdlib ``logging`` module name.
    logging_spec = _parse_logging(root)

    return TaskSpec(
        id=task_id,
        description=description,
        artifacts=artifacts,
        mutation=mutation,
        runner=runner,
        scorer=scorer,
        objective=objective,
        constraints=constraints,
        policy=policy,
        budget=budget,
        logging=logging_spec,
        root_dir=task_path.parent,
    )
|