NOTE(review): the line breaks and indentation of this patch were reconstructed
from a whitespace-collapsed copy. Hunk line counts agree with the @@ headers,
but indentation widths (4-space Python style assumed) and whitespace inside
string literals could not be verified; confirm against the target files before
applying.

Summary of the change set, as shown by the hunks below:
  * load_task: yaml.YAMLError from yaml.safe_load is re-raised as
    TaskValidationError (chained with "from exc").
  * load_task: task.mutation.mode, task.scorer.type, task.scorer.parse.format,
    task.constraints[i].op, task.policy.keep_if and task.policy.on_failure are
    checked against fixed allowed values; anything else raises
    TaskValidationError.
  * engine/task_loader.py gains a trailing newline at end of file.
  * Two tests are added: malformed YAML and an invalid mutation.mode value.

diff --git a/engine/task_loader.py b/engine/task_loader.py
index 2781f26..5bc5145 100644
--- a/engine/task_loader.py
+++ b/engine/task_loader.py
@@ -72,7 +72,10 @@ def load_task(task_path: Path) -> TaskSpec:
             result.append({str(k): str(v) for k, v in entry.items()})
         return result
 
-    task_data = yaml.safe_load(task_path.read_text(encoding="utf-8"))
+    try:
+        task_data = yaml.safe_load(task_path.read_text(encoding="utf-8"))
+    except yaml.YAMLError as exc:
+        raise TaskValidationError(str(exc)) from exc
     root = _require_mapping(task_data, "task")
 
     task_id = _require_str(root, "id", "task")
@@ -86,8 +89,11 @@ def load_task(task_path: Path) -> TaskSpec:
     )
 
     mutation_data = _require_mapping(_require_value(root, "mutation"), "task.mutation")
+    mutation_mode = _require_str(mutation_data, "mode", "task.mutation")
+    if mutation_mode != "direct_edit":
+        raise TaskValidationError("task.mutation.mode must be direct_edit")
     mutation = MutationSpec(
-        mode=_require_str(mutation_data, "mode", "task.mutation"),
+        mode=mutation_mode,
         allowed_file_types=_require_str_list(mutation_data, "allowed_file_types", "task.mutation"),
         max_changed_lines=_require_int(mutation_data, "max_changed_lines", "task.mutation"),
     )
@@ -101,11 +107,17 @@ def load_task(task_path: Path) -> TaskSpec:
 
     scorer_data = _require_mapping(_require_value(root, "scorer"), "task.scorer")
     scorer_parse_data = _require_mapping(_require_value(scorer_data, "parse"), "task.scorer.parse")
+    scorer_type = _require_str(scorer_data, "type", "task.scorer")
+    if scorer_type != "command":
+        raise TaskValidationError("task.scorer.type must be command")
+    scorer_format = _require_str(scorer_parse_data, "format", "task.scorer.parse")
+    if scorer_format != "json":
+        raise TaskValidationError("task.scorer.parse.format must be json")
     scorer = ScorerSpec(
-        type=_require_str(scorer_data, "type", "task.scorer"),
+        type=scorer_type,
         command=_require_str(scorer_data, "command", "task.scorer"),
         parse=ScorerParseSpec(
-            format=_require_str(scorer_parse_data, "format", "task.scorer.parse"),
+            format=scorer_format,
             score_field=_require_str(scorer_parse_data, "score_field", "task.scorer.parse"),
             metrics_field=_require_str(scorer_parse_data, "metrics_field", "task.scorer.parse"),
         ),
@@ -124,19 +136,28 @@ def load_task(task_path: Path) -> TaskSpec:
     constraints = []
     for index, item in enumerate(constraints_data):
         constraint_data = _require_mapping(item, f"task.constraints[{index}]")
+        constraint_op = _require_str(constraint_data, "op", f"task.constraints[{index}]")
+        if constraint_op not in {"<=", ">=", "=="}:
+            raise TaskValidationError(f"task.constraints[{index}].op must be one of <=, >=, ==")
         constraints.append(
             ConstraintSpec(
                 metric=_require_str(constraint_data, "metric", f"task.constraints[{index}]"),
-                op=_require_str(constraint_data, "op", f"task.constraints[{index}]"),
+                op=constraint_op,
                 value=_require_value(constraint_data, "value"),
             )
         )
 
     policy_data = _require_mapping(_require_value(root, "policy"), "task.policy")
+    policy_keep_if = _require_str(policy_data, "keep_if", "task.policy")
+    if policy_keep_if != "better_primary":
+        raise TaskValidationError("task.policy.keep_if must be better_primary")
+    policy_on_failure = _require_str(policy_data, "on_failure", "task.policy")
+    if policy_on_failure != "discard":
+        raise TaskValidationError("task.policy.on_failure must be discard")
     policy = PolicySpec(
-        keep_if=_require_str(policy_data, "keep_if", "task.policy"),
+        keep_if=policy_keep_if,
         tie_breakers=_require_tie_breakers(policy_data, "tie_breakers", "task.policy"),
-        on_failure=_require_str(policy_data, "on_failure", "task.policy"),
+        on_failure=policy_on_failure,
     )
 
     budget_data = _require_mapping(_require_value(root, "budget"), "task.budget")
@@ -164,4 +185,4 @@ def load_task(task_path: Path) -> TaskSpec:
         budget=budget,
         logging=logging,
         root_dir=task_path.parent,
-    )
\ No newline at end of file
+    )
diff --git a/tests/test_task_loader.py b/tests/test_task_loader.py
index 0bdf1e6..fa698ec 100644
--- a/tests/test_task_loader.py
+++ b/tests/test_task_loader.py
@@ -74,6 +74,17 @@ class TaskLoaderTest(unittest.TestCase):
             load_task(self.write_task(content))
         self.assertIn("direction", str(ctx.exception))
 
+    def test_rejects_malformed_yaml(self) -> None:
+        content = VALID_TASK + " bad_indent: [\n"
+        with self.assertRaises(TaskValidationError):
+            load_task(self.write_task(content))
+
+    def test_rejects_invalid_enum_value(self) -> None:
+        content = VALID_TASK.replace("mode: direct_edit", "mode: patch")
+        with self.assertRaises(TaskValidationError) as ctx:
+            load_task(self.write_task(content))
+        self.assertIn("mutation.mode", str(ctx.exception))
+
 
 if __name__ == "__main__":
     unittest.main()