149 lines
5.7 KiB
Python
149 lines
5.7 KiB
Python
import asyncio
|
|
from copy import deepcopy
|
|
from datetime import datetime
|
|
|
|
import pytest
|
|
|
|
from app.config import software_config
|
|
from app.core.cad_batch_manager import CadBatchManager
|
|
from app.core.cad_task_router import CadTaskRouter
|
|
from app.core.plugin_callback_registry import PluginCallbackRegistry
|
|
from app.models.cad_batch import (
|
|
BatchStatus,
|
|
BatchSubmitItem,
|
|
BatchSubmitRequest,
|
|
PluginCallbackPayload,
|
|
PluginResultStatus,
|
|
)
|
|
|
|
|
|
class FakePluginHttpClient:
    """In-process stand-in for the plugin HTTP client driven by ``CadBatchManager``.

    ``submit_task`` records each payload it receives and, depending on
    ``payload["task_params"]["behavior"]``, schedules a result callback into the
    shared ``PluginCallbackRegistry`` roughly 10 ms later:

    * ``"success"`` (default): reports ``PluginResultStatus.SUCCESS``.
    * ``"always_fail"``: reports ``PluginResultStatus.FAILED`` with a simulated
      error message.
    * ``"timeout_once_then_success"``: sends no callback when
      ``payload["attempt"]`` is 0 (a missing ``attempt`` counts as 0), so the
      manager's callback timeout fires; later attempts behave like ``"success"``.
    """

    def __init__(self, callback_registry: PluginCallbackRegistry):
        self._registry = callback_registry
        # Shallow copies of every payload passed to submit_task, in call order.
        self.calls = []
        # Strong references to callback tasks still in flight. The event loop
        # keeps only weak references to tasks, so without this set a pending
        # callback could be garbage-collected before it ever fires.
        self._pending_callbacks = set()

    async def submit_task(self, software_id: str, task_type: str, payload: dict) -> dict:
        """Record *payload* and schedule the simulated plugin callback.

        Args:
            software_id: Target plugin id; echoed back in the callback.
            task_type: Accepted for interface compatibility; unused by the fake.
            payload: Task payload. Reads ``execution_id``, ``model_path``,
                ``attempt`` and ``task_params["behavior"]``.

        Returns:
            ``{"accepted": True}`` in every case.

        Raises:
            KeyError: if *payload* has no ``execution_id`` and a callback is due.
        """
        # Copy so later mutation of the payload by the caller cannot rewrite
        # what the test inspects afterwards.
        self.calls.append(dict(payload))

        # ``or {}`` also tolerates an explicit ``task_params=None``.
        behavior = (payload.get("task_params") or {}).get("behavior", "success")
        attempt = payload.get("attempt", 0)

        if behavior == "timeout_once_then_success" and attempt == 0:
            # Deliberately send no callback so the first attempt times out.
            return {"accepted": True}

        if behavior == "always_fail":
            status = PluginResultStatus.FAILED
            result = {}
            error_message = "simulated plugin failure"
        else:
            status = PluginResultStatus.SUCCESS
            result = {"exported": True, "model": payload.get("model_path")}
            error_message = None

        # Read eagerly: a missing key now fails this call visibly instead of
        # surfacing later as an unobserved exception in a background task.
        execution_id = payload["execution_id"]
        task = asyncio.create_task(
            self._emit_callback(execution_id, software_id, status, result, error_message)
        )
        self._pending_callbacks.add(task)
        task.add_done_callback(self._pending_callbacks.discard)
        return {"accepted": True}

    async def _emit_callback(self, execution_id, software_id, status, result, error_message) -> None:
        """Deliver one simulated plugin result to the registry after a short delay."""
        await asyncio.sleep(0.01)
        await self._registry.handle_callback(
            PluginCallbackPayload(
                execution_id=execution_id,
                software_id=software_id,
                status=status,
                error_message=error_message,
                result=result,
                finished_at=datetime.now(),
                token="ignored-in-registry",
            )
        )
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _configure_short_timeouts():
    """Shrink creo/revit callback timeouts and retries for the duration of each test.

    For both plugins this sets ``callback_timeout_sec=1``,
    ``retry_backoff_sec=[0, 0]`` and ``max_retries=1`` so that timeout/retry
    paths finish quickly. Afterwards, ``software_config._config`` is replaced
    with a deep copy taken before any change was made.
    """
    # NOTE(review): assumes load_config() returns (and populates) the same data
    # held in software_config._config, which is mutated below — confirm.
    backup = deepcopy(software_config.load_config())
    try:
        plugins = software_config._config.setdefault("plugins", {})
        for plugin_id in ("creo", "revit"):
            plugins.setdefault(plugin_id, {}).update(
                callback_timeout_sec=1,
                # A fresh list is built on every iteration, so the two plugin
                # entries never share (and alias) one backoff list.
                retry_backoff_sec=[0, 0],
                max_retries=1,
            )
        yield
    finally:
        # Also runs if an override above raises midway, so partial edits never
        # leak into later tests.
        software_config._config = backup
|
|
|
|
|
|
async def _wait_batch_terminal(
    manager: CadBatchManager,
    batch_id: str,
    timeout_sec: float = 8.0,
    *,
    poll_interval_sec: float = 0.02,
):
    """Poll *manager* until batch *batch_id* reaches a terminal status.

    Terminal statuses are ``COMPLETED``, ``COMPLETED_WITH_ERRORS`` and
    ``FAILED``. The batch is re-fetched every *poll_interval_sec* seconds.

    Args:
        manager: Manager whose ``get_batch`` is polled.
        batch_id: Id of the batch to wait for.
        timeout_sec: Maximum time to wait, measured on the running loop's clock.
        poll_interval_sec: Delay between polls.

    Returns:
        The batch object returned by ``manager.get_batch`` once it is terminal.

    Raises:
        TimeoutError: if no terminal status is observed within *timeout_sec*.
    """
    terminal_statuses = {
        BatchStatus.COMPLETED,
        BatchStatus.COMPLETED_WITH_ERRORS,
        BatchStatus.FAILED,
    }
    # get_running_loop() is the supported way to reach the loop from inside a
    # coroutine; the deadline is computed once instead of on every iteration.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout_sec
    while loop.time() < deadline:
        batch = await manager.get_batch(batch_id)
        if batch and batch.status in terminal_statuses:
            return batch
        await asyncio.sleep(poll_interval_sec)
    raise TimeoutError(f"batch {batch_id} did not finish within {timeout_sec}s")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_serial_executor_fifo_and_continue_after_failure():
    """Items execute in submission order and a failed item does not stop the batch.

    ``b.prt`` always fails, yet ``c.rvt`` must still run, and only after every
    call for ``b.prt`` (including retries) has been made.
    """
    callback_registry = PluginCallbackRegistry()
    fake_client = FakePluginHttpClient(callback_registry)
    router = CadTaskRouter({".prt": "creo", ".rvt": "revit"})
    manager = CadBatchManager(task_router=router, plugin_client=fake_client, callback_registry=callback_registry)

    submitted_order = ["a.prt", "b.prt", "c.rvt"]
    request = BatchSubmitRequest(
        items=[
            BatchSubmitItem(model_path="a.prt", task_type="export", task_params={"behavior": "success"}),
            BatchSubmitItem(model_path="b.prt", task_type="export", task_params={"behavior": "always_fail"}),
            BatchSubmitItem(model_path="c.rvt", task_type="export", task_params={"behavior": "success"}),
        ],
        batch_name="fifo-case",
    )

    # Start only after all setup that could raise, so every successful start is
    # paired with the stop() in the finally block below.
    await manager.start()
    try:
        batch = await manager.create_batch(request, submitter_id="tester")
        final_batch = await _wait_batch_terminal(manager, batch.id)
        items = await manager.get_batch_items(batch.id)

        assert final_batch.status == BatchStatus.COMPLETED_WITH_ERRORS
        assert [item.status.value for item in items] == ["succeeded", "failed", "succeeded"]

        call_models = [call["model_path"] for call in fake_client.calls]
        assert call_models[0] == "a.prt"
        assert call_models[-1] == "c.rvt"
        assert set(call_models) == set(submitted_order), call_models
        # Retries may repeat a model, but calls must never return to an earlier
        # item: one item at a time, in FIFO order.
        ranks = [submitted_order.index(path) for path in call_models]
        assert ranks == sorted(ranks), call_models
    finally:
        await manager.stop()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_serial_executor_retry_after_timeout_then_success():
    """An item whose first attempt gets no callback is retried and then succeeds."""
    callback_registry = PluginCallbackRegistry()
    fake_client = FakePluginHttpClient(callback_registry)
    router = CadTaskRouter({".prt": "creo"})
    manager = CadBatchManager(task_router=router, plugin_client=fake_client, callback_registry=callback_registry)

    request = BatchSubmitRequest(
        items=[
            BatchSubmitItem(
                model_path="timeout_once.prt",
                task_type="export",
                task_params={"behavior": "timeout_once_then_success"},
            )
        ]
    )

    # Start only after all setup that could raise, so every successful start is
    # paired with the stop() in the finally block below.
    await manager.start()
    try:
        batch = await manager.create_batch(request, submitter_id="tester")
        final_batch = await _wait_batch_terminal(manager, batch.id)
        items = await manager.get_batch_items(batch.id)

        assert final_batch.status == BatchStatus.COMPLETED
        assert len(items) == 1
        assert items[0].status.value == "succeeded"
        # The first attempt receives no callback, so at least one retry must
        # have been submitted, all for the same model.
        assert len(fake_client.calls) >= 2
        assert all(call["model_path"] == "timeout_once.prt" for call in fake_client.calls)
    finally:
        await manager.stop()
|