from __future__ import annotations import asyncio import uuid from datetime import datetime from typing import Dict, List, Optional, Set from app.config import settings, software_config from app.core.cad_task_router import CadTaskRouter, RouteNotFoundError from app.core.plugin_callback_registry import PluginCallbackRegistry from app.core.plugin_http_client import PluginHttpClient from app.core.serial_batch_executor import SerialBatchExecutor from app.models.cad_batch import ( BatchItem, BatchItemStatus, BatchJob, BatchStatus, BatchSubmitRequest, PluginCallbackPayload, ) from app.models.operation_log import OperationStatus TERMINAL_BATCH_STATUSES: Set[BatchStatus] = { BatchStatus.COMPLETED, BatchStatus.COMPLETED_WITH_ERRORS, BatchStatus.FAILED, } class CadBatchManager: def __init__( self, task_router: Optional[CadTaskRouter] = None, plugin_client: Optional[PluginHttpClient] = None, callback_registry: Optional[PluginCallbackRegistry] = None, ): self._router = task_router or CadTaskRouter() self._plugin_client = plugin_client or PluginHttpClient() self._callback_registry = callback_registry or PluginCallbackRegistry() self._executor = SerialBatchExecutor(self, self._plugin_client, self._callback_registry) self._batches: Dict[str, BatchJob] = {} self._items: Dict[str, BatchItem] = {} self._websocket_manager = None self._log_manager = None self._lock = asyncio.Lock() def set_websocket_manager(self, websocket_manager): self._websocket_manager = websocket_manager def set_log_manager(self, log_manager): self._log_manager = log_manager async def start(self): await self._executor.start() async def stop(self): await self._executor.stop() await self._callback_registry.clear() async def create_batch(self, request: BatchSubmitRequest, submitter_id: Optional[str] = None) -> BatchJob: batch_id = str(uuid.uuid4()) batch = BatchJob( id=batch_id, name=request.batch_name, submitter_id=submitter_id, metadata=request.metadata, ) enqueue_items: List[str] = [] async with self._lock: 
self._batches[batch_id] = batch for idx, submit_item in enumerate(request.items): item_id = str(uuid.uuid4()) status = BatchItemStatus.QUEUED file_extension = None software_id = None max_retries = 0 error_message = None try: file_extension, software_id = self._router.resolve(submit_item.model_path) plugin_cfg = software_config.get_plugin_config(software_id) or {} max_retries = int(plugin_cfg.get("max_retries", 0)) except RouteNotFoundError as exc: status = BatchItemStatus.FAILED error_message = str(exc) item = BatchItem( id=item_id, batch_id=batch_id, sequence=idx, model_path=submit_item.model_path, file_extension=file_extension, software_id=software_id, task_type=submit_item.task_type, task_params=submit_item.task_params, status=status, max_retries=max_retries, error_message=error_message, finished_at=datetime.now() if status == BatchItemStatus.FAILED else None, ) self._items[item_id] = item batch.item_ids.append(item_id) if status == BatchItemStatus.QUEUED: enqueue_items.append(item_id) self._recalculate_batch(batch_id) batch_snapshot = self._batches[batch_id].model_copy(deep=True) for item_id in enqueue_items: await self._executor.enqueue(item_id) await self._notify_websocket( "batch_created", { "batch": batch_snapshot.model_dump(mode="json"), "items": [i.model_dump(mode="json") for i in await self.get_batch_items(batch_id)], }, ) await self._log_system_operation( operation="batch_create", details=f"Created batch {batch_id} with {batch_snapshot.total_count} items", status=OperationStatus.SUCCESS, extra_data={ "batch_id": batch_id, "submitter_id": submitter_id, "total_count": batch_snapshot.total_count, }, ) return batch_snapshot async def get_batch(self, batch_id: str) -> Optional[BatchJob]: async with self._lock: batch = self._batches.get(batch_id) return batch.model_copy(deep=True) if batch else None async def get_item(self, item_id: str) -> Optional[BatchItem]: async with self._lock: item = self._items.get(item_id) return item.model_copy(deep=True) if item 
else None async def get_batch_items(self, batch_id: str) -> List[BatchItem]: async with self._lock: batch = self._batches.get(batch_id) if not batch: return [] return [self._items[item_id].model_copy(deep=True) for item_id in batch.item_ids if item_id in self._items] def get_callback_endpoint_url(self) -> str: base = getattr(settings, "batch_callback_base_url", None) or f"http://localhost:{settings.port}" return f"{base.rstrip('/')}/api/v1/plugin-callbacks/task-result" def get_callback_timeout_sec(self, software_id: str) -> int: plugin_cfg = software_config.get_plugin_config(software_id) or {} return int(plugin_cfg.get("callback_timeout_sec", 60)) def get_retry_backoff_sec(self, software_id: str, attempt_idx: int) -> int: plugin_cfg = software_config.get_plugin_config(software_id) or {} backoff = plugin_cfg.get("retry_backoff_sec", [1, 3, 5]) if not isinstance(backoff, list) or not backoff: return 0 idx = min(attempt_idx, len(backoff) - 1) try: return int(backoff[idx]) except (TypeError, ValueError): return 0 def get_task_completion_mode(self, software_id: str, task_type: str) -> str: task_cfg = software_config.get_plugin_task_config(software_id, task_type) or {} mode = task_cfg.get("completion_mode", "callback") if mode not in {"callback", "sync"}: return "callback" return mode def validate_callback_token(self, software_id: str, token: Optional[str]) -> bool: plugin_cfg = software_config.get_plugin_config(software_id) or {} expected = plugin_cfg.get("callback_token") if not expected: return False return token == expected async def handle_plugin_callback(self, payload: PluginCallbackPayload) -> bool: accepted = await self._callback_registry.handle_callback(payload) if not accepted: await self._log_system_operation( operation="plugin_callback_duplicate", details=f"Duplicate callback ignored for execution_id={payload.execution_id}", status=OperationStatus.SUCCESS, extra_data={"execution_id": payload.execution_id, "software_id": payload.software_id}, ) return accepted 
    async def mark_item_dispatching(self, item_id: str, execution_id: str, attempt: int) -> None:
        """Mark the item as being dispatched to its plugin (attempt number recorded)."""
        await self._update_item(
            item_id,
            {
                "status": BatchItemStatus.DISPATCHING,
                "execution_id": execution_id,
                "retry_count": attempt,
                "started_at": datetime.now(),
                # A fresh attempt clears any error from a previous one.
                "error_message": None,
            },
        )

    async def mark_item_waiting_callback(self, item_id: str, execution_id: str, plugin_response: dict) -> None:
        """Mark the item as dispatched and now waiting for the plugin's callback."""
        await self._update_item(
            item_id,
            {
                "status": BatchItemStatus.WAITING_CALLBACK,
                "execution_id": execution_id,
                "plugin_response": plugin_response or {},
            },
        )

    async def mark_item_retrying(self, item_id: str, error_message: str, next_attempt: int) -> None:
        """Re-queue the item for another attempt, keeping the failure reason."""
        await self._update_item(
            item_id,
            {
                "status": BatchItemStatus.QUEUED,
                "retry_count": next_attempt,
                "error_message": error_message,
            },
        )

    async def mark_item_succeeded(self, item_id: str, result: dict) -> None:
        """Mark the item as successfully finished with its result payload."""
        await self._update_item(
            item_id,
            {
                "status": BatchItemStatus.SUCCEEDED,
                "result": result or {},
                "finished_at": datetime.now(),
                "error_message": None,
            },
        )

    async def mark_item_failed(self, item_id: str, error_message: str) -> None:
        """Mark the item as terminally failed."""
        await self._update_item(
            item_id,
            {
                "status": BatchItemStatus.FAILED,
                "error_message": error_message,
                "finished_at": datetime.now(),
            },
        )

    async def _update_item(self, item_id: str, changes: dict) -> None:
        """Apply *changes* to an item, recompute its batch, and notify observers.

        The mutation, batch recalculation, and deep-copy snapshots all
        happen under the lock; websocket broadcasts and log writes use only
        the snapshots and run after the lock is released. Unknown item or
        batch ids are ignored silently.
        """
        batch_snapshot = None
        item_snapshot = None
        emit_batch_completed = False
        async with self._lock:
            item = self._items.get(item_id)
            if not item:
                return
            batch = self._batches.get(item.batch_id)
            if not batch:
                return
            prev_batch_status = batch.status
            for key, value in changes.items():
                setattr(item, key, value)
            self._recalculate_batch(item.batch_id)
            item_snapshot = item.model_copy(deep=True)
            batch_snapshot = self._batches[item.batch_id].model_copy(deep=True)
            # Fire "batch_completed" only on the transition INTO a terminal
            # status, not on later updates to an already-terminal batch.
            emit_batch_completed = (
                prev_batch_status not in TERMINAL_BATCH_STATUSES
                and batch_snapshot.status in TERMINAL_BATCH_STATUSES
            )
        await self._notify_websocket(
            "batch_item_update",
            {
                "item": item_snapshot.model_dump(mode="json"),
                "batch": batch_snapshot.model_dump(mode="json"),
            },
        )
        if item_snapshot.status == BatchItemStatus.FAILED:
            op_status = OperationStatus.FAILED
        else:
            op_status = OperationStatus.SUCCESS
        await self._log_system_operation(
            operation="batch_item_update",
            details=f"Item {item_snapshot.id} status changed to {item_snapshot.status.value}",
            status=op_status,
            extra_data={
                "batch_id": batch_snapshot.id,
                "item_id": item_snapshot.id,
                "status": item_snapshot.status.value,
                "software_id": item_snapshot.software_id,
                "error_message": item_snapshot.error_message,
            },
        )
        if emit_batch_completed:
            await self._notify_websocket(
                "batch_completed",
                {
                    "batch": batch_snapshot.model_dump(mode="json"),
                    "items": [i.model_dump(mode="json") for i in await self.get_batch_items(batch_snapshot.id)],
                },
            )

    def _recalculate_batch(self, batch_id: str) -> None:
        """Recompute the batch's counters and derived status from its items.

        Must be called with ``self._lock`` held (callers create_batch and
        _update_item both hold it). Status precedence: any running item =>
        RUNNING; else any queued => PENDING; else terminal (FAILED /
        COMPLETED_WITH_ERRORS / COMPLETED based on failure counts).
        """
        batch = self._batches[batch_id]
        items = [self._items[item_id] for item_id in batch.item_ids if item_id in self._items]
        total = len(items)
        queued = sum(1 for i in items if i.status == BatchItemStatus.QUEUED)
        running = sum(
            1
            for i in items
            if i.status in {BatchItemStatus.DISPATCHING, BatchItemStatus.WAITING_CALLBACK}
        )
        succeeded = sum(1 for i in items if i.status == BatchItemStatus.SUCCEEDED)
        failed = sum(1 for i in items if i.status == BatchItemStatus.FAILED)
        batch.total_count = total
        batch.queued_count = queued
        batch.running_count = running
        batch.succeeded_count = succeeded
        batch.failed_count = failed
        if running > 0:
            batch.status = BatchStatus.RUNNING
            if batch.started_at is None:
                batch.started_at = datetime.now()
            # Still in flight: a previously-set finish time is stale.
            batch.finished_at = None
            return
        if queued > 0:
            batch.status = BatchStatus.PENDING
            return
        # An empty batch is treated as failed immediately.
        if total == 0:
            batch.status = BatchStatus.FAILED
            batch.finished_at = datetime.now()
            return
        if failed == total:
            batch.status = BatchStatus.FAILED
        elif failed > 0:
            batch.status = BatchStatus.COMPLETED_WITH_ERRORS
        else:
            batch.status = BatchStatus.COMPLETED
        if batch.started_at is None:
            batch.started_at = datetime.now()
        batch.finished_at = datetime.now()

    async def _notify_websocket(self, message_type: str, data: dict) -> None:
        """Broadcast a typed, timestamped event; no-op if no manager attached."""
        if self._websocket_manager:
            await self._websocket_manager.broadcast(
                {
                    "type": message_type,
                    "data": data,
                    "timestamp": datetime.now().isoformat(),
                }
            )

    async def _log_system_operation(self, operation: str, details: str, status=OperationStatus.SUCCESS, **kwargs) -> None:
        """Write an operation-log entry under 'batch_processing'; no-op without a log manager."""
        if self._log_manager:
            await self._log_manager.log_system_operation(
                operation=operation,
                details=details,
                status=status,
                operation_category="batch_processing",
                **kwargs,
            )


# Module-level singleton used by the application's API/executor layers.
cad_batch_manager = CadBatchManager()