- Added configuration for file storage and software plugins in `software_config.yaml`.
- Created core components for batch processing, including `CadBatchManager`, `CadTaskRouter`, and `SerialBatchExecutor`.
- Implemented plugin callback handling with `PluginCallbackRegistry` and an HTTP client for task submission.
- Developed an API endpoint for receiving plugin callbacks in `plugin_callbacks.py`.
- Enhanced the data models for batch processing, including `BatchJob`, `BatchItem`, and callback payloads.
- Introduced WebSocket support for real-time updates on batch-processing status.
- Added comprehensive tests for routing, the callback API, and serial executor behavior.
- Documented the implementation plan and core execution rules in `cad-batch-plan.md`.
106 lines
3.3 KiB
Python
106 lines
3.3 KiB
Python
"""WebSocket connection manager."""
|
|
|
|
import asyncio
import json
import logging
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional

from fastapi import WebSocket
|
|
|
|
|
|
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DateTimeEncoder(json.JSONEncoder):
    """``json.JSONEncoder`` that serializes date/time values as ISO-8601 strings.

    ``datetime``, ``date`` and ``time`` objects are converted with their
    ``isoformat()`` method. Any other object the standard encoder cannot
    handle is passed on to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        """Return a JSON-serializable replacement for *obj*.

        Raises:
            TypeError: if *obj* is not a ``date``/``datetime``/``time``.
        """
        # Imported locally: the file binds the name ``datetime`` to the class,
        # not the module. ``datetime`` subclasses ``date``, so this check also
        # covers full timestamps.
        from datetime import date, time

        if isinstance(obj, (date, time)):
            return obj.isoformat()
        return super().default(obj)
|
|
|
|
|
|
class MessageType(str, Enum):
    """Identifiers for the ``"type"`` field of WebSocket message payloads.

    Because the enum mixes in ``str``, every member compares equal to its
    string value and ``json.dumps`` writes it out as a plain JSON string.
    """

    # General status / diagnostic messages
    SOFTWARE_STATUS = "software_status"
    PROCESS_UPDATE = "process_update"
    ERROR = "error"
    INFO = "info"
    HEARTBEAT = "heartbeat"

    # Task and software lifecycle events
    TASK_UPDATE = "task_update"
    SOFTWARE_STARTED = "software_started"
    SOFTWARE_START_FAILED = "software_start_failed"
    SOFTWARE_LIST_UPDATE = "software_list_update"

    # Operation-log events
    LOG_OPERATION = "log_operation"
    LOG_RECORDED = "log_recorded"

    # Batch-processing events
    BATCH_CREATED = "batch_created"
    BATCH_ITEM_UPDATE = "batch_item_update"
    BATCH_COMPLETED = "batch_completed"
|
|
|
|
|
|
class WebSocketManager:
    """Registry of active WebSocket connections with send/broadcast helpers.

    Connections are keyed by a caller-chosen ``client_id``. A client may also
    be associated with a ``user_id`` so that :meth:`broadcast_to_user` can
    reach every connection that belongs to that user. Messages are sent as
    JSON text serialized with :class:`DateTimeEncoder`.
    """

    def __init__(self) -> None:
        # client_id -> accepted WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # client_id -> user_id; only clients that connected with a user_id appear here
        self.connection_users: Dict[str, str] = {}

    async def connect(
        self, websocket: WebSocket, client_id: str, user_id: Optional[str] = None
    ) -> None:
        """Accept *websocket*, register it as *client_id* and send a greeting.

        Args:
            websocket: The incoming connection; its ``accept()`` is awaited here.
            client_id: Registry key. Reusing an id replaces the previously
                registered socket in the registry (the old socket is not
                closed here).
            user_id: Optional owner of the connection. When omitted or empty,
                any user association left over from an earlier connection
                with the same *client_id* is cleared.
        """
        await websocket.accept()

        self.active_connections[client_id] = websocket
        if user_id:
            self.connection_users[client_id] = user_id
        else:
            # Without this, a reconnect that carries no user_id would silently
            # inherit the previous connection's user and receive its messages.
            self.connection_users.pop(client_id, None)

        logger.info("WebSocket connected: client_id=%s user_id=%s", client_id, user_id)

        await self.send_personal_message(
            {
                "type": MessageType.INFO,
                "message": "WebSocket connected",
                "timestamp": self._get_timestamp(),
            },
            client_id,
        )

    def disconnect(self, client_id: str, websocket: Optional[WebSocket] = None) -> None:
        """Forget *client_id* and its user association.

        Args:
            client_id: Registry key to remove; unknown ids are ignored.
            websocket: If given, the entry is removed only while it still
                refers to this exact socket. This keeps cleanup for a stale
                connection from evicting a newer one that reused *client_id*.
        """
        if websocket is not None and self.active_connections.get(client_id) is not websocket:
            return
        self.active_connections.pop(client_id, None)
        self.connection_users.pop(client_id, None)
        logger.info("WebSocket disconnected: client_id=%s", client_id)

    async def send_personal_message(self, message: dict, client_id: str) -> None:
        """Send *message* as JSON text to *client_id*; unknown clients are skipped.

        A failed send is logged and the client is dropped from the registry.
        A message that cannot be serialized (``TypeError``/``ValueError`` from
        ``json.dumps``) is logged and discarded without touching the connection.
        """
        websocket = self.active_connections.get(client_id)
        if websocket is None:
            return

        try:
            payload = json.dumps(message, ensure_ascii=False, cls=DateTimeEncoder)
        except (TypeError, ValueError) as exc:
            # A bad payload is the sender's bug, not a dead connection: log it
            # and keep the client registered.
            logger.error("Failed to serialize websocket message for %s: %s", client_id, exc)
            return

        try:
            await websocket.send_text(payload)
        except Exception as exc:
            # Best-effort delivery: any transport error is treated as a dead socket.
            logger.error("Failed to send websocket message to %s: %s", client_id, exc)
            # Pass the socket we actually used, so a newer connection that
            # registered under the same id while we were awaiting survives.
            self.disconnect(client_id, websocket)

    async def broadcast(self, message: dict) -> None:
        """Send *message* to every active connection concurrently."""
        # Snapshot the ids up front; failed sends remove entries from the registry.
        client_ids = list(self.active_connections)
        if not client_ids:
            return
        # return_exceptions=True: one failing send must not abort the others.
        await asyncio.gather(
            *(self.send_personal_message(message, cid) for cid in client_ids),
            return_exceptions=True,
        )

    async def broadcast_to_user(self, message: dict, user_id: str) -> None:
        """Send *message* to every connection registered for *user_id*."""
        target_clients = [cid for cid, uid in self.connection_users.items() if uid == user_id]
        if not target_clients:
            return
        await asyncio.gather(
            *(self.send_personal_message(message, cid) for cid in target_clients),
            return_exceptions=True,
        )

    def get_active_connections_count(self) -> int:
        """Return the number of registered connections."""
        return len(self.active_connections)

    def get_connected_users(self) -> List[str]:
        """Return the distinct user ids with at least one connection (order unspecified)."""
        return list(set(self.connection_users.values()))

    def _get_timestamp(self) -> str:
        """Return the current local time as a naive ISO-8601 string (no UTC offset)."""
        return datetime.now().isoformat()
|
|
|
|
|
|
# Shared module-level manager instance. Presumably imported by the rest of
# the app instead of constructing new managers — confirm against callers.
websocket_manager: WebSocketManager = WebSocketManager()
|