# (viewer metadata, not source: 389 lines, 14 KiB, Python)
from __future__ import annotations

import asyncio
import hmac
import uuid
from datetime import datetime
from typing import Dict, List, Optional, Set

from app.config import settings, software_config
from app.core.cad_task_router import CadTaskRouter, RouteNotFoundError
from app.core.plugin_callback_registry import PluginCallbackRegistry
from app.core.plugin_http_client import PluginHttpClient
from app.core.serial_batch_executor import SerialBatchExecutor
from app.models.cad_batch import (
    BatchItem,
    BatchItemStatus,
    BatchJob,
    BatchStatus,
    BatchSubmitRequest,
    PluginCallbackPayload,
)
from app.models.operation_log import OperationStatus
|
|
|
|
|
|
# Batch statuses that represent a finished batch. _update_item uses this set
# to detect the transition into a terminal state and fire "batch_completed"
# exactly once.
TERMINAL_BATCH_STATUSES: Set[BatchStatus] = {
    BatchStatus.COMPLETED,
    BatchStatus.COMPLETED_WITH_ERRORS,
    BatchStatus.FAILED,
}
|
|
|
|
|
|
class CadBatchManager:
    """In-memory manager for CAD batch jobs and their items.

    Owns the batch/item lookup tables, routes submitted items to plugins via
    a CadTaskRouter, and drives execution through a SerialBatchExecutor.
    All reads and writes of the shared state go through a single asyncio
    lock; snapshots handed to callers are deep copies.
    """

    def __init__(
        self,
        task_router: Optional[CadTaskRouter] = None,
        plugin_client: Optional[PluginHttpClient] = None,
        callback_registry: Optional[PluginCallbackRegistry] = None,
    ):
        # Collaborators; defaults are constructed here so the module-level
        # singleton at the bottom of the file can be created without args.
        self._router = task_router or CadTaskRouter()
        self._plugin_client = plugin_client or PluginHttpClient()
        self._callback_registry = callback_registry or PluginCallbackRegistry()
        self._executor = SerialBatchExecutor(self, self._plugin_client, self._callback_registry)

        # batch_id -> BatchJob and item_id -> BatchItem lookup tables.
        self._batches: Dict[str, BatchJob] = {}
        self._items: Dict[str, BatchItem] = {}

        # Optional collaborators injected later via the setters below.
        self._websocket_manager = None
        self._log_manager = None
        # Serializes every mutation of _batches / _items.
        self._lock = asyncio.Lock()
|
|
|
|
    def set_websocket_manager(self, websocket_manager):
        """Inject the websocket manager used to broadcast batch events."""
        self._websocket_manager = websocket_manager
|
|
|
|
    def set_log_manager(self, log_manager):
        """Inject the log manager used for operation-log entries."""
        self._log_manager = log_manager
|
|
|
|
    async def start(self):
        """Start the serial batch executor's worker loop."""
        await self._executor.start()
|
|
|
|
    async def stop(self):
        """Stop the executor and drop any pending callback registrations."""
        await self._executor.stop()
        await self._callback_registry.clear()
|
|
|
|
async def create_batch(self, request: BatchSubmitRequest, submitter_id: Optional[str] = None) -> BatchJob:
|
|
batch_id = str(uuid.uuid4())
|
|
batch = BatchJob(
|
|
id=batch_id,
|
|
name=request.batch_name,
|
|
submitter_id=submitter_id,
|
|
metadata=request.metadata,
|
|
)
|
|
|
|
enqueue_items: List[str] = []
|
|
|
|
async with self._lock:
|
|
self._batches[batch_id] = batch
|
|
|
|
for idx, submit_item in enumerate(request.items):
|
|
item_id = str(uuid.uuid4())
|
|
|
|
status = BatchItemStatus.QUEUED
|
|
file_extension = None
|
|
software_id = None
|
|
max_retries = 0
|
|
error_message = None
|
|
|
|
try:
|
|
file_extension, software_id = self._router.resolve(submit_item.model_path)
|
|
plugin_cfg = software_config.get_plugin_config(software_id) or {}
|
|
max_retries = int(plugin_cfg.get("max_retries", 0))
|
|
except RouteNotFoundError as exc:
|
|
status = BatchItemStatus.FAILED
|
|
error_message = str(exc)
|
|
|
|
item = BatchItem(
|
|
id=item_id,
|
|
batch_id=batch_id,
|
|
sequence=idx,
|
|
model_path=submit_item.model_path,
|
|
file_extension=file_extension,
|
|
software_id=software_id,
|
|
task_type=submit_item.task_type,
|
|
task_params=submit_item.task_params,
|
|
status=status,
|
|
max_retries=max_retries,
|
|
error_message=error_message,
|
|
finished_at=datetime.now() if status == BatchItemStatus.FAILED else None,
|
|
)
|
|
|
|
self._items[item_id] = item
|
|
batch.item_ids.append(item_id)
|
|
|
|
if status == BatchItemStatus.QUEUED:
|
|
enqueue_items.append(item_id)
|
|
|
|
self._recalculate_batch(batch_id)
|
|
batch_snapshot = self._batches[batch_id].model_copy(deep=True)
|
|
|
|
for item_id in enqueue_items:
|
|
await self._executor.enqueue(item_id)
|
|
|
|
await self._notify_websocket(
|
|
"batch_created",
|
|
{
|
|
"batch": batch_snapshot.model_dump(mode="json"),
|
|
"items": [i.model_dump(mode="json") for i in await self.get_batch_items(batch_id)],
|
|
},
|
|
)
|
|
|
|
await self._log_system_operation(
|
|
operation="batch_create",
|
|
details=f"Created batch {batch_id} with {batch_snapshot.total_count} items",
|
|
status=OperationStatus.SUCCESS,
|
|
extra_data={
|
|
"batch_id": batch_id,
|
|
"submitter_id": submitter_id,
|
|
"total_count": batch_snapshot.total_count,
|
|
},
|
|
)
|
|
|
|
return batch_snapshot
|
|
|
|
async def get_batch(self, batch_id: str) -> Optional[BatchJob]:
|
|
async with self._lock:
|
|
batch = self._batches.get(batch_id)
|
|
return batch.model_copy(deep=True) if batch else None
|
|
|
|
async def get_item(self, item_id: str) -> Optional[BatchItem]:
|
|
async with self._lock:
|
|
item = self._items.get(item_id)
|
|
return item.model_copy(deep=True) if item else None
|
|
|
|
async def get_batch_items(self, batch_id: str) -> List[BatchItem]:
|
|
async with self._lock:
|
|
batch = self._batches.get(batch_id)
|
|
if not batch:
|
|
return []
|
|
return [self._items[item_id].model_copy(deep=True) for item_id in batch.item_ids if item_id in self._items]
|
|
|
|
def get_callback_endpoint_url(self) -> str:
|
|
base = getattr(settings, "batch_callback_base_url", None) or f"http://localhost:{settings.port}"
|
|
return f"{base.rstrip('/')}/api/v1/plugin-callbacks/task-result"
|
|
|
|
def get_callback_timeout_sec(self, software_id: str) -> int:
|
|
plugin_cfg = software_config.get_plugin_config(software_id) or {}
|
|
return int(plugin_cfg.get("callback_timeout_sec", 60))
|
|
|
|
def get_retry_backoff_sec(self, software_id: str, attempt_idx: int) -> int:
|
|
plugin_cfg = software_config.get_plugin_config(software_id) or {}
|
|
backoff = plugin_cfg.get("retry_backoff_sec", [1, 3, 5])
|
|
if not isinstance(backoff, list) or not backoff:
|
|
return 0
|
|
idx = min(attempt_idx, len(backoff) - 1)
|
|
try:
|
|
return int(backoff[idx])
|
|
except (TypeError, ValueError):
|
|
return 0
|
|
|
|
def get_task_completion_mode(self, software_id: str, task_type: str) -> str:
|
|
task_cfg = software_config.get_plugin_task_config(software_id, task_type) or {}
|
|
mode = task_cfg.get("completion_mode", "callback")
|
|
if mode not in {"callback", "sync"}:
|
|
return "callback"
|
|
return mode
|
|
|
|
def validate_callback_token(self, software_id: str, token: Optional[str]) -> bool:
|
|
plugin_cfg = software_config.get_plugin_config(software_id) or {}
|
|
expected = plugin_cfg.get("callback_token")
|
|
if not expected:
|
|
return False
|
|
return token == expected
|
|
|
|
async def handle_plugin_callback(self, payload: PluginCallbackPayload) -> bool:
|
|
accepted = await self._callback_registry.handle_callback(payload)
|
|
if not accepted:
|
|
await self._log_system_operation(
|
|
operation="plugin_callback_duplicate",
|
|
details=f"Duplicate callback ignored for execution_id={payload.execution_id}",
|
|
status=OperationStatus.SUCCESS,
|
|
extra_data={"execution_id": payload.execution_id, "software_id": payload.software_id},
|
|
)
|
|
return accepted
|
|
|
|
async def mark_item_dispatching(self, item_id: str, execution_id: str, attempt: int):
|
|
await self._update_item(
|
|
item_id,
|
|
{
|
|
"status": BatchItemStatus.DISPATCHING,
|
|
"execution_id": execution_id,
|
|
"retry_count": attempt,
|
|
"started_at": datetime.now(),
|
|
"error_message": None,
|
|
},
|
|
)
|
|
|
|
async def mark_item_waiting_callback(self, item_id: str, execution_id: str, plugin_response: dict):
|
|
await self._update_item(
|
|
item_id,
|
|
{
|
|
"status": BatchItemStatus.WAITING_CALLBACK,
|
|
"execution_id": execution_id,
|
|
"plugin_response": plugin_response or {},
|
|
},
|
|
)
|
|
|
|
async def mark_item_retrying(self, item_id: str, error_message: str, next_attempt: int):
|
|
await self._update_item(
|
|
item_id,
|
|
{
|
|
"status": BatchItemStatus.QUEUED,
|
|
"retry_count": next_attempt,
|
|
"error_message": error_message,
|
|
},
|
|
)
|
|
|
|
async def mark_item_succeeded(self, item_id: str, result: dict):
|
|
await self._update_item(
|
|
item_id,
|
|
{
|
|
"status": BatchItemStatus.SUCCEEDED,
|
|
"result": result or {},
|
|
"finished_at": datetime.now(),
|
|
"error_message": None,
|
|
},
|
|
)
|
|
|
|
async def mark_item_failed(self, item_id: str, error_message: str):
|
|
await self._update_item(
|
|
item_id,
|
|
{
|
|
"status": BatchItemStatus.FAILED,
|
|
"error_message": error_message,
|
|
"finished_at": datetime.now(),
|
|
},
|
|
)
|
|
|
|
    async def _update_item(self, item_id: str, changes: dict):
        """Apply *changes* to an item, recompute its batch, and emit notifications.

        All state mutation happens under the lock; the websocket and log
        notifications are sent afterwards using deep-copied snapshots taken
        inside the lock, so no notification is emitted while the lock is held.
        Unknown item or batch ids are ignored silently.
        """
        batch_snapshot = None
        item_snapshot = None
        emit_batch_completed = False

        async with self._lock:
            item = self._items.get(item_id)
            if not item:
                return

            batch = self._batches.get(item.batch_id)
            if not batch:
                return

            # Remember the pre-update status so we can detect the transition
            # into a terminal state and fire "batch_completed" exactly once.
            prev_batch_status = batch.status

            for key, value in changes.items():
                setattr(item, key, value)

            self._recalculate_batch(item.batch_id)

            item_snapshot = item.model_copy(deep=True)
            batch_snapshot = self._batches[item.batch_id].model_copy(deep=True)

            emit_batch_completed = (
                prev_batch_status not in TERMINAL_BATCH_STATUSES
                and batch_snapshot.status in TERMINAL_BATCH_STATUSES
            )

        await self._notify_websocket(
            "batch_item_update",
            {
                "item": item_snapshot.model_dump(mode="json"),
                "batch": batch_snapshot.model_dump(mode="json"),
            },
        )

        # Item failures are logged as FAILED operations; everything else
        # (queueing, dispatching, success) is a SUCCESS log entry.
        if item_snapshot.status == BatchItemStatus.FAILED:
            op_status = OperationStatus.FAILED
        else:
            op_status = OperationStatus.SUCCESS

        await self._log_system_operation(
            operation="batch_item_update",
            details=f"Item {item_snapshot.id} status changed to {item_snapshot.status.value}",
            status=op_status,
            extra_data={
                "batch_id": batch_snapshot.id,
                "item_id": item_snapshot.id,
                "status": item_snapshot.status.value,
                "software_id": item_snapshot.software_id,
                "error_message": item_snapshot.error_message,
            },
        )

        if emit_batch_completed:
            await self._notify_websocket(
                "batch_completed",
                {
                    "batch": batch_snapshot.model_dump(mode="json"),
                    "items": [i.model_dump(mode="json") for i in await self.get_batch_items(batch_snapshot.id)],
                },
            )
|
|
|
|
def _recalculate_batch(self, batch_id: str):
|
|
batch = self._batches[batch_id]
|
|
items = [self._items[item_id] for item_id in batch.item_ids if item_id in self._items]
|
|
|
|
total = len(items)
|
|
queued = sum(1 for i in items if i.status == BatchItemStatus.QUEUED)
|
|
running = sum(
|
|
1
|
|
for i in items
|
|
if i.status in {BatchItemStatus.DISPATCHING, BatchItemStatus.WAITING_CALLBACK}
|
|
)
|
|
succeeded = sum(1 for i in items if i.status == BatchItemStatus.SUCCEEDED)
|
|
failed = sum(1 for i in items if i.status == BatchItemStatus.FAILED)
|
|
|
|
batch.total_count = total
|
|
batch.queued_count = queued
|
|
batch.running_count = running
|
|
batch.succeeded_count = succeeded
|
|
batch.failed_count = failed
|
|
|
|
if running > 0:
|
|
batch.status = BatchStatus.RUNNING
|
|
if batch.started_at is None:
|
|
batch.started_at = datetime.now()
|
|
batch.finished_at = None
|
|
return
|
|
|
|
if queued > 0:
|
|
batch.status = BatchStatus.PENDING
|
|
return
|
|
|
|
if total == 0:
|
|
batch.status = BatchStatus.FAILED
|
|
batch.finished_at = datetime.now()
|
|
return
|
|
|
|
if failed == total:
|
|
batch.status = BatchStatus.FAILED
|
|
elif failed > 0:
|
|
batch.status = BatchStatus.COMPLETED_WITH_ERRORS
|
|
else:
|
|
batch.status = BatchStatus.COMPLETED
|
|
|
|
if batch.started_at is None:
|
|
batch.started_at = datetime.now()
|
|
batch.finished_at = datetime.now()
|
|
|
|
async def _notify_websocket(self, message_type: str, data: dict):
|
|
if self._websocket_manager:
|
|
await self._websocket_manager.broadcast(
|
|
{
|
|
"type": message_type,
|
|
"data": data,
|
|
"timestamp": datetime.now().isoformat(),
|
|
}
|
|
)
|
|
|
|
async def _log_system_operation(self, operation: str, details: str, status=OperationStatus.SUCCESS, **kwargs):
|
|
if self._log_manager:
|
|
await self._log_manager.log_system_operation(
|
|
operation=operation,
|
|
details=details,
|
|
status=status,
|
|
operation_category="batch_processing",
|
|
**kwargs,
|
|
)
|
|
|
|
|
|
# Module-level singleton used by the application wiring (routers, startup hooks).
cad_batch_manager = CadBatchManager()
|