CommonAutoRearsh/tests/test_artifact_manager.py

113 lines
5.1 KiB
Python

from pathlib import Path
from hashlib import sha256
import tempfile
import unittest
from engine.artifact_manager import ArtifactManager
from engine.models import ArtifactSpec, BaselineSnapshot, TaskSpec
from engine.models import BudgetSpec, ConstraintSpec, LoggingSpec, MutationSpec, ObjectiveSpec, PolicySpec, RunnerSpec, ScorerParseSpec, ScorerSpec
def make_task(root_dir: Path) -> TaskSpec:
    """Build the small ``demo`` TaskSpec used by the ArtifactManager tests.

    The task tracks ``artifacts/*.md`` under *root_dir*, excluding
    ``artifacts/ignore.md``. Its runner and scorer are inline ``python -c``
    commands. The scorer prints a fixed JSON payload with ``score`` 1 and
    ``metrics.violation_count`` 0. ``budget.max_iterations`` is 1.

    Args:
        root_dir: Directory the task's relative paths are resolved against.

    Returns:
        A fully populated ``TaskSpec`` with ``root_dir`` set to *root_dir*.
    """
    artifact_spec = ArtifactSpec(
        include=["artifacts/*.md"],
        exclude=["artifacts/ignore.md"],
        max_files_per_iteration=1,
    )
    mutation_spec = MutationSpec(
        mode="direct_edit",
        allowed_file_types=[".md"],
        max_changed_lines=20,
    )
    runner_spec = RunnerSpec(
        command="python -c \"print('run')\"",
        cwd=".",
        timeout_seconds=10,
    )
    scorer_spec = ScorerSpec(
        type="command",
        command="python -c \"import json; print(json.dumps({'score': 1, 'metrics': {'violation_count': 0}}))\"",
        parse=ScorerParseSpec(
            format="json",
            score_field="score",
            metrics_field="metrics",
        ),
    )
    objective_spec = ObjectiveSpec(primary_metric="score", direction="maximize")
    # The scorer above always reports violation_count == 0, so this constraint holds.
    constraint_specs = [ConstraintSpec(metric="violation_count", op="<=", value=0)]
    policy_spec = PolicySpec(keep_if="better_primary", tie_breakers=[], on_failure="discard")
    budget_spec = BudgetSpec(max_iterations=1, max_failures=1)
    logging_spec = LoggingSpec(
        results_file="work/results.jsonl",
        candidate_dir="work/candidates",
    )
    return TaskSpec(
        id="demo",
        description="Demo",
        artifacts=artifact_spec,
        mutation=mutation_spec,
        runner=runner_spec,
        scorer=scorer_spec,
        objective=objective_spec,
        constraints=constraint_specs,
        policy=policy_spec,
        budget=budget_spec,
        logging=logging_spec,
        root_dir=root_dir,
    )
class ArtifactManagerTest(unittest.TestCase):
    """Tests for ArtifactManager snapshot/restore, path resolution and diffing.

    Each test runs against a fresh temporary project root that contains an
    empty ``artifacts/`` directory. The root is deleted when the test finishes.
    """

    def setUp(self) -> None:
        # Registering cleanup immediately guarantees removal even if the test fails.
        workspace = tempfile.TemporaryDirectory()
        self.addCleanup(workspace.cleanup)
        self.root = Path(workspace.name)
        self.artifact_dir = self.root / "artifacts"
        self.artifact_dir.mkdir()

    def _manager(self) -> ArtifactManager:
        """Return an ArtifactManager for the demo task rooted at the temp dir."""
        return ArtifactManager(make_task(self.root))

    def test_snapshot_and_restore(self) -> None:
        sample = self.artifact_dir / "sample.md"
        sample.write_text("hello\n", encoding="utf-8")
        manager = self._manager()
        saved = manager.snapshot()
        sample.write_text("changed\n", encoding="utf-8")
        manager.restore(saved)
        self.assertEqual(sample.read_text(encoding="utf-8"), "hello\n")

    def test_restore_removes_newly_created_included_artifact(self) -> None:
        sample = self.artifact_dir / "sample.md"
        sample.write_text("hello\n", encoding="utf-8")
        manager = self._manager()
        saved = manager.snapshot()
        # new.md matches the include glob but did not exist at snapshot time.
        late_file = self.artifact_dir / "new.md"
        late_file.write_text("new\n", encoding="utf-8")
        manager.restore(saved)
        self.assertFalse(late_file.exists())
        self.assertEqual(sample.read_text(encoding="utf-8"), "hello\n")

    def test_snapshot_and_restore_preserve_crlf_content(self) -> None:
        sample = self.artifact_dir / "sample.md"
        crlf_text = "line1\r\nline2\r\n"
        # newline="" disables newline translation, so \r\n reaches the disk unchanged.
        with sample.open("w", encoding="utf-8", newline="") as fh:
            fh.write(crlf_text)
        manager = self._manager()
        saved = manager.snapshot()
        self.assertEqual(saved.file_contents[sample], crlf_text)
        expected_digest = sha256(crlf_text.encode("utf-8")).hexdigest()
        self.assertEqual(saved.file_hashes[sample], expected_digest)
        with sample.open("w", encoding="utf-8", newline="") as fh:
            fh.write("changed\r\n")
        manager.restore(saved)
        # Read back untranslated so a lost \r would make the comparison fail.
        with sample.open("r", encoding="utf-8", newline="") as fh:
            round_tripped = fh.read()
        self.assertEqual(round_tripped, crlf_text)

    def test_resolve_paths_is_deterministic_and_respects_excludes(self) -> None:
        # Files are written out of alphabetical order. The expected result is
        # the sorted list without the excluded ignore.md.
        seed_files = (("b.md", "b\n"), ("ignore.md", "ignore\n"), ("a.md", "a\n"))
        for file_name, body in seed_files:
            (self.artifact_dir / file_name).write_text(body, encoding="utf-8")
        resolved = self._manager().resolve_paths()
        expected = [self.artifact_dir / "a.md", self.artifact_dir / "b.md"]
        self.assertEqual(resolved, expected)

    def test_diff_summary_contains_changed_line(self) -> None:
        sample = self.artifact_dir / "sample.md"
        sample.write_text("before\n", encoding="utf-8")
        manager = self._manager()
        saved = manager.snapshot()
        sample.write_text("after\n", encoding="utf-8")
        summary = manager.diff_summary(saved)
        self.assertIn("-before", summary)
        self.assertIn("+after", summary)
if __name__ == "__main__":
    # Allow running this module directly; unittest.main discovers the TestCase above.
    unittest.main()