CadHubManage/app/core/websocket_manager.py
sladro b19ef1467a feat: Implement CAD batch processing framework with plugin callback handling
- Added configuration for file storage and software plugins in `software_config.yaml`.
- Created core components for batch processing including `CadBatchManager`, `CadTaskRouter`, and `SerialBatchExecutor`.
- Implemented plugin callback handling with `PluginCallbackRegistry` and HTTP client for task submission.
- Developed API endpoint for receiving plugin callbacks in `plugin_callbacks.py`.
- Enhanced data models for batch processing including `BatchJob`, `BatchItem`, and callback payloads.
- Introduced WebSocket support for real-time updates on batch processing status.
- Added comprehensive tests for routing, callback API, and serial executor behavior.
- Documented the implementation plan and core execution rules in `cad-batch-plan.md`.
2026-03-01 08:48:10 +08:00

106 lines
3.3 KiB
Python

"""WebSocket connection manager."""
import asyncio
import json
import logging
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional

from fastapi import WebSocket
# Logger named after this module (``__name__``).
logger = logging.getLogger(__name__)
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that renders ``datetime`` values as ISO-8601 strings.

    Every other object the standard encoder cannot handle is delegated to
    ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, obj):
        if not isinstance(obj, datetime):
            # Not ours to handle: let the base class raise TypeError.
            return super().default(obj)
        return obj.isoformat()
class MessageType(str, Enum):
    """Values used for the ``"type"`` field of outgoing WebSocket messages.

    Members mix in ``str``, so each compares equal to its plain string value
    and is JSON-encoded as that string.
    """

    # General status / diagnostics
    SOFTWARE_STATUS = "software_status"
    PROCESS_UPDATE = "process_update"
    ERROR = "error"
    INFO = "info"
    HEARTBEAT = "heartbeat"
    TASK_UPDATE = "task_update"

    # Software lifecycle
    SOFTWARE_STARTED = "software_started"
    SOFTWARE_START_FAILED = "software_start_failed"
    SOFTWARE_LIST_UPDATE = "software_list_update"

    # Operation logging
    LOG_OPERATION = "log_operation"
    LOG_RECORDED = "log_recorded"

    # CAD batch processing
    BATCH_CREATED = "batch_created"
    BATCH_ITEM_UPDATE = "batch_item_update"
    BATCH_COMPLETED = "batch_completed"
class WebSocketManager:
    """Registry of active WebSocket connections with send/broadcast helpers.

    Connections are keyed by a caller-supplied ``client_id``. A connection may
    optionally be associated with a ``user_id`` so that messages can be sent
    to every connection belonging to that user.

    The send helpers never raise for a failed delivery. A socket whose send
    fails is logged and removed from the registry. A message that cannot be
    JSON-encoded is logged and dropped without touching any connection.
    """

    def __init__(self):
        # client_id -> accepted WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # client_id -> user_id, only for connections registered with a user_id
        self.connection_users: Dict[str, str] = {}

    async def connect(
        self, websocket: WebSocket, client_id: str, user_id: Optional[str] = None
    ) -> None:
        """Accept ``websocket``, register it under ``client_id`` and greet it.

        If ``client_id`` is already registered, the new socket replaces the
        previous entry (the previous socket is not closed here).

        Args:
            websocket: The incoming connection; ``accept()`` is awaited first.
            client_id: Registry key for this connection.
            user_id: Optional owner, used by ``broadcast_to_user``. Falsy
                values (``None``, ``""``) mean "no user".
        """
        await websocket.accept()
        self.active_connections[client_id] = websocket
        if user_id:
            self.connection_users[client_id] = user_id
        else:
            # A reconnect that reuses client_id without a user must not
            # inherit the previous connection's user (and that user's messages).
            self.connection_users.pop(client_id, None)
        logger.info("WebSocket connected: client_id=%s user_id=%s", client_id, user_id)
        await self.send_personal_message(
            {
                "type": MessageType.INFO,
                "message": "WebSocket connected",
                "timestamp": self._get_timestamp(),
            },
            client_id,
        )

    def disconnect(self, client_id: str, websocket: Optional[WebSocket] = None) -> None:
        """Remove the connection registered under ``client_id``.

        Args:
            client_id: Registry key of the connection to remove.
            websocket: If given, the entry is removed only while it still
                refers to this exact socket. A stale socket therefore cannot
                evict a newer connection that reused the same ``client_id``.
                When omitted, whatever is registered under ``client_id`` is
                removed.
        """
        if websocket is not None and self.active_connections.get(client_id) is not websocket:
            return
        self.active_connections.pop(client_id, None)
        self.connection_users.pop(client_id, None)
        logger.info("WebSocket disconnected: client_id=%s", client_id)

    async def send_personal_message(self, message: dict, client_id: str) -> None:
        """JSON-encode ``message`` and send it to ``client_id``.

        Does nothing if ``client_id`` is not registered. Failures are logged;
        a socket whose send fails is removed from the registry.
        """
        if client_id not in self.active_connections:
            return
        text = self._encode(message, client_id)
        if text is not None:
            await self._send_text(text, client_id)

    async def broadcast(self, message: dict) -> None:
        """Send ``message`` to every registered connection concurrently."""
        if not self.active_connections:
            return
        # Encode once for all recipients. A bad payload is dropped here instead
        # of failing (and disconnecting) each client individually.
        text = self._encode(message, "broadcast")
        if text is None:
            return
        # Snapshot the keys: failed sends remove entries while gather() runs.
        client_ids = list(self.active_connections)
        await asyncio.gather(
            *(self._send_text(text, cid) for cid in client_ids),
            return_exceptions=True,
        )

    async def broadcast_to_user(self, message: dict, user_id: str) -> None:
        """Send ``message`` to every connection registered for ``user_id``."""
        target_clients = [cid for cid, uid in self.connection_users.items() if uid == user_id]
        if not target_clients:
            return
        text = self._encode(message, f"user_id={user_id}")
        if text is None:
            return
        await asyncio.gather(
            *(self._send_text(text, cid) for cid in target_clients),
            return_exceptions=True,
        )

    def get_active_connections_count(self) -> int:
        """Return the number of registered connections."""
        return len(self.active_connections)

    def get_connected_users(self) -> List[str]:
        """Return the distinct user_ids that have a connection (in no particular order)."""
        return list(set(self.connection_users.values()))

    @staticmethod
    def _encode(message: dict, target: str) -> Optional[str]:
        """Serialize ``message`` to JSON text, or log and return ``None`` on failure.

        ``target`` only labels the log line.
        """
        try:
            return json.dumps(message, ensure_ascii=False, cls=DateTimeEncoder)
        except Exception:
            # An unencodable payload is a server-side bug, not a dead socket:
            # report it with a traceback but keep every connection open.
            logger.exception("Failed to encode websocket message for %s", target)
            return None

    async def _send_text(self, text: str, client_id: str) -> None:
        """Send pre-encoded ``text`` to ``client_id``; drop the socket if the send fails."""
        websocket = self.active_connections.get(client_id)
        if websocket is None:
            return
        try:
            await websocket.send_text(text)
        except Exception as exc:
            logger.error("Failed to send websocket message to %s: %s", client_id, exc)
            # Pass the socket so that, if client_id was re-registered while this
            # send was pending, the newer connection is left alone.
            self.disconnect(client_id, websocket)

    def _get_timestamp(self) -> str:
        """Return the current local time as an ISO-8601 string (naive, no UTC offset)."""
        return datetime.now().isoformat()
# Shared manager instance, created when this module is first imported; every
# module that imports ``websocket_manager`` gets this same object and therefore
# the same connection registry.
websocket_manager = WebSocketManager()