fix: normalize task loader validation
This commit is contained in:
parent
c3acdb96f7
commit
db3ae7cff1
@ -72,7 +72,10 @@ def load_task(task_path: Path) -> TaskSpec:
|
||||
result.append({str(k): str(v) for k, v in entry.items()})
|
||||
return result
|
||||
|
||||
task_data = yaml.safe_load(task_path.read_text(encoding="utf-8"))
|
||||
try:
|
||||
task_data = yaml.safe_load(task_path.read_text(encoding="utf-8"))
|
||||
except yaml.YAMLError as exc:
|
||||
raise TaskValidationError(str(exc)) from exc
|
||||
root = _require_mapping(task_data, "task")
|
||||
|
||||
task_id = _require_str(root, "id", "task")
|
||||
@ -86,8 +89,11 @@ def load_task(task_path: Path) -> TaskSpec:
|
||||
)
|
||||
|
||||
mutation_data = _require_mapping(_require_value(root, "mutation"), "task.mutation")
|
||||
mutation_mode = _require_str(mutation_data, "mode", "task.mutation")
|
||||
if mutation_mode != "direct_edit":
|
||||
raise TaskValidationError("task.mutation.mode must be direct_edit")
|
||||
mutation = MutationSpec(
|
||||
mode=_require_str(mutation_data, "mode", "task.mutation"),
|
||||
mode=mutation_mode,
|
||||
allowed_file_types=_require_str_list(mutation_data, "allowed_file_types", "task.mutation"),
|
||||
max_changed_lines=_require_int(mutation_data, "max_changed_lines", "task.mutation"),
|
||||
)
|
||||
@ -101,11 +107,17 @@ def load_task(task_path: Path) -> TaskSpec:
|
||||
|
||||
scorer_data = _require_mapping(_require_value(root, "scorer"), "task.scorer")
|
||||
scorer_parse_data = _require_mapping(_require_value(scorer_data, "parse"), "task.scorer.parse")
|
||||
scorer_type = _require_str(scorer_data, "type", "task.scorer")
|
||||
if scorer_type != "command":
|
||||
raise TaskValidationError("task.scorer.type must be command")
|
||||
scorer_format = _require_str(scorer_parse_data, "format", "task.scorer.parse")
|
||||
if scorer_format != "json":
|
||||
raise TaskValidationError("task.scorer.parse.format must be json")
|
||||
scorer = ScorerSpec(
|
||||
type=_require_str(scorer_data, "type", "task.scorer"),
|
||||
type=scorer_type,
|
||||
command=_require_str(scorer_data, "command", "task.scorer"),
|
||||
parse=ScorerParseSpec(
|
||||
format=_require_str(scorer_parse_data, "format", "task.scorer.parse"),
|
||||
format=scorer_format,
|
||||
score_field=_require_str(scorer_parse_data, "score_field", "task.scorer.parse"),
|
||||
metrics_field=_require_str(scorer_parse_data, "metrics_field", "task.scorer.parse"),
|
||||
),
|
||||
@ -124,19 +136,28 @@ def load_task(task_path: Path) -> TaskSpec:
|
||||
constraints = []
|
||||
for index, item in enumerate(constraints_data):
|
||||
constraint_data = _require_mapping(item, f"task.constraints[{index}]")
|
||||
constraint_op = _require_str(constraint_data, "op", f"task.constraints[{index}]")
|
||||
if constraint_op not in {"<=", ">=", "=="}:
|
||||
raise TaskValidationError(f"task.constraints[{index}].op must be one of <=, >=, ==")
|
||||
constraints.append(
|
||||
ConstraintSpec(
|
||||
metric=_require_str(constraint_data, "metric", f"task.constraints[{index}]"),
|
||||
op=_require_str(constraint_data, "op", f"task.constraints[{index}]"),
|
||||
op=constraint_op,
|
||||
value=_require_value(constraint_data, "value"),
|
||||
)
|
||||
)
|
||||
|
||||
policy_data = _require_mapping(_require_value(root, "policy"), "task.policy")
|
||||
policy_keep_if = _require_str(policy_data, "keep_if", "task.policy")
|
||||
if policy_keep_if != "better_primary":
|
||||
raise TaskValidationError("task.policy.keep_if must be better_primary")
|
||||
policy_on_failure = _require_str(policy_data, "on_failure", "task.policy")
|
||||
if policy_on_failure != "discard":
|
||||
raise TaskValidationError("task.policy.on_failure must be discard")
|
||||
policy = PolicySpec(
|
||||
keep_if=_require_str(policy_data, "keep_if", "task.policy"),
|
||||
keep_if=policy_keep_if,
|
||||
tie_breakers=_require_tie_breakers(policy_data, "tie_breakers", "task.policy"),
|
||||
on_failure=_require_str(policy_data, "on_failure", "task.policy"),
|
||||
on_failure=policy_on_failure,
|
||||
)
|
||||
|
||||
budget_data = _require_mapping(_require_value(root, "budget"), "task.budget")
|
||||
@ -164,4 +185,4 @@ def load_task(task_path: Path) -> TaskSpec:
|
||||
budget=budget,
|
||||
logging=logging,
|
||||
root_dir=task_path.parent,
|
||||
)
|
||||
)
|
||||
|
||||
@ -74,6 +74,17 @@ class TaskLoaderTest(unittest.TestCase):
|
||||
load_task(self.write_task(content))
|
||||
self.assertIn("direction", str(ctx.exception))
|
||||
|
||||
def test_rejects_malformed_yaml(self) -> None:
|
||||
content = VALID_TASK + " bad_indent: [\n"
|
||||
with self.assertRaises(TaskValidationError):
|
||||
load_task(self.write_task(content))
|
||||
|
||||
def test_rejects_invalid_enum_value(self) -> None:
|
||||
content = VALID_TASK.replace("mode: direct_edit", "mode: patch")
|
||||
with self.assertRaises(TaskValidationError) as ctx:
|
||||
load_task(self.write_task(content))
|
||||
self.assertIn("mutation.mode", str(ctx.exception))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user