CommonAutoRearsh/engine/orchestrator.py

253 lines
9.8 KiB
Python

from __future__ import annotations

import os
import shutil
import tempfile
from dataclasses import replace
from hashlib import sha256
from pathlib import Path

from engine.artifact_manager import ArtifactManager
from engine.decision_engine import decide_candidate
from engine.models import BaselineSnapshot, DecisionResult, TaskSpec
from engine.mutation_engine import MutationValidationError, validate_candidate_changes
from engine.runner import run_command
from engine.scorer import parse_score_output
# Names that, when they are the first component of a repo-relative path, are
# neither copied into the sandbox nor included when hashing a tree for changes.
_SANDBOX_EXCLUDED_ROOTS = frozenset(("work", ".venv", ".pytest_cache"))
def _normalize_relative_path(raw_path: str) -> Path:
    """Return ``raw_path`` as a Path with empty and ``"."`` components removed.

    An input made only of such components (for example ``""`` or ``"."``)
    yields ``Path()``, which has no parts.
    """
    kept_segments = (segment for segment in Path(raw_path).parts if segment not in {"", "."})
    return Path(*kept_segments)
def _validate_sandbox_relative_path(raw_path: str, field_name: str) -> Path:
    """Check that ``raw_path`` cannot escape the sandbox and return it normalized.

    Args:
        raw_path: Path string taken from the task configuration.
        field_name: Configuration field name, used in error messages.

    Returns:
        ``raw_path`` with empty and ``"."`` components removed.

    Raises:
        ValueError: If ``raw_path`` is absolute or has a ``..`` component.
    """
    candidate = Path(raw_path)
    if candidate.is_absolute():
        raise ValueError(f"{field_name} must be relative to the sandbox root")
    if ".." in candidate.parts:
        raise ValueError(f"{field_name} must not contain '..'")
    # Same normalization as _normalize_relative_path, applied inline.
    return Path(*(segment for segment in candidate.parts if segment not in {"", "."}))
def _infer_repo_root(task: TaskSpec, candidate_paths: list[Path]) -> Path:
    """Pick the ancestor of the task root that contains the most candidate directories.

    Every directory from the resolved ``task.root_dir`` up to the filesystem
    root is scored by how many non-empty ``candidate_paths`` exist as
    directories beneath it. The highest score wins. On a tie, the directory
    closest to the task root wins. With no non-empty candidates, the resolved
    task root is returned.
    """
    start = task.root_dir.resolve()
    relevant = [candidate for candidate in candidate_paths if candidate.parts]
    if not relevant:
        return start

    def hits(ancestor: Path) -> int:
        return sum((ancestor / candidate).is_dir() for candidate in relevant)

    # max() returns the first maximal element. The sequence runs from the
    # nearest ancestor outward, so ties resolve toward the task root.
    return max([start, *start.parents], key=hits)
def _is_sandbox_excluded_path(relative_path: Path) -> bool:
    """Return True when the first component of ``relative_path`` is an excluded root.

    Only the leading component is inspected, so excluded names deeper in the
    tree do not match. A path with no parts is never excluded.
    """
    parts = relative_path.parts
    if not parts:
        return False
    return parts[0] in _SANDBOX_EXCLUDED_ROOTS
def _copy_repo_to_sandbox(repo_root: Path, sandbox_root: Path) -> None:
    """Copy each immediate child of ``repo_root`` into ``sandbox_root``.

    ``.git`` and excluded root names are skipped. Directories are copied
    recursively, merging into an existing destination. Every other entry is
    copied with ``shutil.copy2``.
    """
    for entry in repo_root.iterdir():
        skip = entry.name == ".git" or _is_sandbox_excluded_path(Path(entry.name))
        if skip:
            continue
        target = sandbox_root / entry.name
        if entry.is_dir():
            shutil.copytree(entry, target, dirs_exist_ok=True)
        else:
            shutil.copy2(entry, target)
def _sandbox_task(task: TaskSpec, sandbox_root: Path, repo_root: Path) -> TaskSpec:
    """Return a copy of ``task`` whose ``root_dir`` points into the sandbox.

    The task root keeps its position relative to ``repo_root``. The new root
    is ``sandbox_root / <resolved task root relative to repo_root>``. All
    other fields are carried over by ``dataclasses.replace``.
    """
    offset = task.root_dir.resolve().relative_to(repo_root)
    sandboxed_root = sandbox_root / offset
    return replace(task, root_dir=sandboxed_root)
def _sandbox_snapshot(task: TaskSpec, sandbox_task: TaskSpec, snapshot: BaselineSnapshot) -> BaselineSnapshot:
    """Return ``snapshot`` re-keyed from the real task root to the sandbox task root.

    Each file path is made relative to ``task.root_dir`` and re-anchored under
    ``sandbox_task.root_dir``. Contents and hashes are carried over unchanged.
    """
    origin = task.root_dir
    destination = sandbox_task.root_dir
    moved_contents: dict[Path, str] = {}
    moved_hashes: dict[Path, str] = {}
    for original_path in snapshot.file_contents:
        relocated = destination / original_path.relative_to(origin)
        moved_contents[relocated] = snapshot.file_contents[original_path]
        moved_hashes[relocated] = snapshot.file_hashes[original_path]
    return BaselineSnapshot(file_hashes=moved_hashes, file_contents=moved_contents)
def _repo_file_hashes(root: Path) -> dict[Path, str]:
    """Return SHA-256 hex digests of the regular files under ``root``.

    Keys are paths relative to ``root``. The following are skipped, and the
    directories among them are pruned rather than walked:

    * any entry named ``.git``, at any depth;
    * anything whose first relative component is an excluded sandbox root.

    Symlinked directories are not descended into. Symlinks to regular files
    are hashed through the link. Directories that cannot be listed are skipped
    silently (``os.walk``'s default error handling).

    Args:
        root: Directory to scan.

    Returns:
        Mapping of ``root``-relative path to lowercase hex SHA-256 digest.
    """
    file_hashes: dict[Path, str] = {}
    for dir_path, dir_names, file_names in os.walk(root):
        current_dir = Path(dir_path)
        relative_dir = current_dir.relative_to(root)
        # Prune in place so os.walk never enters .git or excluded roots. This
        # avoids traversing large trees such as .venv on every comparison. The
        # check uses relative components only, so a ".git" directory *above*
        # root no longer hides every file.
        dir_names[:] = [
            name
            for name in dir_names
            if name != ".git" and not _is_sandbox_excluded_path(relative_dir / name)
        ]
        for name in file_names:
            relative_path = relative_dir / name
            if name == ".git" or _is_sandbox_excluded_path(relative_path):
                continue
            file_path = current_dir / name
            # Follows symlinks: broken links, FIFOs, sockets, etc. are skipped.
            if not file_path.is_file():
                continue
            digest = sha256()
            with file_path.open("rb") as handle:
                # Stream in 1 MiB chunks so large files are never loaded whole.
                while chunk := handle.read(1 << 20):
                    digest.update(chunk)
            file_hashes[relative_path] = digest.hexdigest()
    return file_hashes
def _validate_keepable_candidate(
    task: TaskSpec,
    sandbox_task: TaskSpec,
    baseline_snapshot: BaselineSnapshot,
    repo_root: Path,
    sandbox_root: Path,
) -> None:
    """Reject sandbox changes to files that are not task artifacts.

    Both the real repository and the sandbox are hashed with
    ``_repo_file_hashes``. A file counts as changed when it was added, removed,
    or its digest differs. A changed file is allowed only if it is a baseline
    artifact or an artifact currently resolved in the sandbox. Paths are
    compared relative to the repository root.

    Raises:
        MutationValidationError: Naming the first offending path in sorted order.
    """
    task_offset = task.root_dir.resolve().relative_to(repo_root.resolve())
    baseline_artifacts = {
        task_offset / snapshot_path.relative_to(task.root_dir)
        for snapshot_path in baseline_snapshot.file_contents
    }
    sandbox_artifacts = {
        task_offset / artifact_path.relative_to(sandbox_task.root_dir)
        for artifact_path in ArtifactManager(sandbox_task).resolve_paths()
    }
    permitted = baseline_artifacts | sandbox_artifacts
    before = _repo_file_hashes(repo_root)
    after = _repo_file_hashes(sandbox_root)
    offending = sorted(
        changed
        for changed in before.keys() | after.keys()
        if before.get(changed) != after.get(changed) and changed not in permitted
    )
    if offending:
        raise MutationValidationError(f"non-artifact change detected: {offending[0].as_posix()}")
def _validate_candidate_state(
    task: TaskSpec,
    sandbox_task: TaskSpec,
    baseline_snapshot: BaselineSnapshot,
    repo_root: Path,
    sandbox_root: Path,
) -> None:
    """Run both post-mutation checks on the sandbox.

    First, reject changes to any non-artifact file in the sandboxed
    repository. Then delegate artifact-level checks to
    ``validate_candidate_changes``. Failures propagate as exceptions. The
    caller in this module catches ``MutationValidationError``.
    """
    _validate_keepable_candidate(
        task=task,
        sandbox_task=sandbox_task,
        baseline_snapshot=baseline_snapshot,
        repo_root=repo_root,
        sandbox_root=sandbox_root,
    )
    candidate_root = sandbox_task.root_dir
    validate_candidate_changes(task, baseline_snapshot, candidate_root)
def _validate_final_candidate_artifacts(
    task: TaskSpec,
    sandbox_task: TaskSpec,
    baseline_snapshot: BaselineSnapshot,
) -> None:
    """Re-run only the artifact-level check against the sandbox task root.

    Unlike ``_validate_candidate_state``, this does not compare non-artifact
    files. It only calls ``validate_candidate_changes``.
    """
    candidate_root = sandbox_task.root_dir
    validate_candidate_changes(task, baseline_snapshot, candidate_root)
def _resolve_sandbox_cwd(sandbox_root: Path, relative_cwd: Path, field_name: str) -> Path:
    """Join ``relative_cwd`` onto ``sandbox_root`` and require an existing directory.

    Args:
        sandbox_root: Root of the sandbox copy.
        relative_cwd: Already-validated relative working directory.
        field_name: Configuration field name, used in the error message.

    Raises:
        ValueError: If the joined path is not an existing directory.
    """
    working_dir = sandbox_root / relative_cwd
    if working_dir.is_dir():
        return working_dir
    raise ValueError(f"{field_name} does not exist in sandbox: {relative_cwd.as_posix()}")
def _sync_artifacts_back(task: TaskSpec, sandbox_task: TaskSpec) -> None:
    """Mirror the sandbox's artifact files onto the real task directory.

    Every artifact resolved for ``sandbox_task`` is written to the same
    relative location under ``task.root_dir``, creating parent directories as
    needed. Every artifact resolved for ``task`` with no counterpart in the
    sandbox is then deleted. Files are read and written as UTF-8 text with
    newline translation disabled (``newline=""``), so line endings round-trip.

    All sandbox artifacts are read and decoded before any real file is
    touched. A read or decode failure (for example, a non-UTF-8 artifact)
    therefore raises without modifying the real task directory.

    Raises:
        UnicodeDecodeError: If a sandbox artifact is not valid UTF-8.
        OSError: On filesystem errors while reading, writing, or deleting.
    """
    source_manager = ArtifactManager(sandbox_task)
    target_manager = ArtifactManager(task)
    # Stage everything first. Opening a target with "w" truncates it at once,
    # so reading the source afterwards could destroy the real file on a decode
    # error and leave the sync half-applied.
    staged_contents: dict[Path, str] = {}
    for source_path in source_manager.resolve_paths():
        relative_path = source_path.relative_to(sandbox_task.root_dir)
        with source_path.open("r", encoding="utf-8", newline="") as source_handle:
            staged_contents[relative_path] = source_handle.read()
    for relative_path, content in staged_contents.items():
        target_path = task.root_dir / relative_path
        target_path.parent.mkdir(parents=True, exist_ok=True)
        with target_path.open("w", encoding="utf-8", newline="") as target_handle:
            target_handle.write(content)
    # Remove real artifacts that no longer exist in the sandbox.
    for target_path in target_manager.resolve_paths():
        if target_path.relative_to(task.root_dir) not in staged_contents:
            target_path.unlink()
def _crash(reason: str, baseline_score: float | None) -> DecisionResult:
    """Build a ``"crash"`` decision carrying ``reason`` and no candidate score."""
    outcome = DecisionResult(
        status="crash",
        baseline_score=baseline_score,
        candidate_score=None,
        reason=reason,
    )
    return outcome
def run_single_iteration(task: TaskSpec, baseline_score: float | None) -> DecisionResult:
    """Run one mutate -> run -> score -> decide cycle for ``task`` in a sandbox.

    The repository enclosing the task (chosen by ``_infer_repo_root``) is
    copied into a temporary directory. The mutator, runner, and scorer
    commands all execute against that copy. Sandbox artifacts are synced back
    to the real task directory only when the decision's status is ``"keep"``
    and the final artifact validation passes.

    Args:
        task: Task specification. Its ``mutator``, ``runner``, ``scorer``,
            ``objective``, ``constraints`` and ``policy`` settings are read.
        baseline_score: Score of the current baseline. It is passed through to
            every result and to ``decide_candidate``.

    Returns:
        One of:

        * a ``"crash"`` result when a configured cwd is invalid or missing in
          the sandbox, a command exits non-zero, or the score cannot be parsed;
        * a ``"discard"`` result (with ``candidate_score=None``) when
          candidate validation fails;
        * otherwise, the result of ``decide_candidate``.

        Exceptions not caught below (e.g. from copying the repository)
        propagate to the caller.
    """
    manager = ArtifactManager(task)
    # Snapshot the real, pre-mutation artifacts; both validation passes below
    # compare the sandbox against this snapshot.
    baseline_snapshot = manager.snapshot()
    try:
        mutator_relative_cwd = _validate_sandbox_relative_path(task.mutator.cwd, "mutator.cwd")
        runner_relative_cwd = _validate_sandbox_relative_path(task.runner.cwd, "runner.cwd")
    except ValueError as exc:
        return _crash(str(exc), baseline_score)
    # Copy the enclosing repo, not just the task dir, so configured cwds that
    # sit outside the task root still exist in the sandbox.
    repo_root = _infer_repo_root(task, [mutator_relative_cwd, runner_relative_cwd])
    with tempfile.TemporaryDirectory(prefix="autoresearch-orchestrator-") as sandbox_dir:
        sandbox_root = Path(sandbox_dir)
        _copy_repo_to_sandbox(repo_root, sandbox_root)
        sandbox_task = _sandbox_task(task, sandbox_root, repo_root)
        try:
            mutator_cwd = _resolve_sandbox_cwd(sandbox_root, mutator_relative_cwd, "mutator.cwd")
        except ValueError as exc:
            return _crash(str(exc), baseline_score)
        mutator_result = run_command(task.mutator.command, mutator_cwd, task.mutator.timeout_seconds)
        if mutator_result.exit_code != 0:
            return _crash(f"mutator failed with exit code {mutator_result.exit_code}", baseline_score)
        # Validate immediately after mutation, before the runner can add its
        # own outputs to the sandbox.
        try:
            _validate_candidate_state(task, sandbox_task, baseline_snapshot, repo_root, sandbox_root)
        except MutationValidationError as exc:
            return DecisionResult(
                status="discard",
                reason=str(exc),
                baseline_score=baseline_score,
                candidate_score=None,
            )
        try:
            runner_cwd = _resolve_sandbox_cwd(sandbox_root, runner_relative_cwd, "runner.cwd")
        except ValueError as exc:
            return _crash(str(exc), baseline_score)
        runner_result = run_command(task.runner.command, runner_cwd, task.runner.timeout_seconds)
        if runner_result.exit_code != 0:
            return _crash(f"command failed with exit code {runner_result.exit_code}", baseline_score)
        # NOTE(review): the scorer always runs from the sandbox root, unlike the
        # mutator/runner which use their configured cwd -- confirm intended.
        scorer_result = run_command(task.scorer.command, sandbox_root, task.scorer.timeout_seconds)
        if scorer_result.exit_code != 0:
            return _crash(f"scorer failed with exit code {scorer_result.exit_code}", baseline_score)
        try:
            candidate_score = parse_score_output(
                scorer_result.stdout,
                score_field=task.scorer.parse.score_field,
                metrics_field=task.scorer.parse.metrics_field,
            )
        except ValueError as exc:
            return _crash(f"score parse failed: {exc}", baseline_score)
        decision = decide_candidate(
            baseline=baseline_score,
            candidate=candidate_score,
            objective=task.objective,
            constraints=task.constraints,
            tie_breakers=task.policy.tie_breakers,
            run_result=runner_result,
        )
        if decision.status == "keep":
            # Re-check artifacts: the runner and scorer ran after the first
            # validation and may have modified them.
            try:
                _validate_final_candidate_artifacts(task, sandbox_task, baseline_snapshot)
            except MutationValidationError as exc:
                return DecisionResult(
                    status="discard",
                    reason=str(exc),
                    baseline_score=baseline_score,
                    candidate_score=None,
                )
            # Must run inside the with-block, before the sandbox is deleted.
            _sync_artifacts_back(task, sandbox_task)
        return decision