CadHubManage/app/core/cad_batch_manager.py
sladro 08623bf4d6 feat: Enhance SerialBatchExecutor with pre-batch cleanup and task execution improvements
- Added pre-batch cleanup functionality to SerialBatchExecutor, allowing for cleanup tasks before processing items.
- Introduced new task execution phases and improved error handling for task submissions.
- Implemented inter-step delays and between-items delays for better task management.
- Updated logging to capture detailed events during batch processing.
- Enhanced configuration options for plugins in software_config.yaml to support new features.
- Added tests for pre-batch cleanup and auto-close scenarios to ensure robust handling of edge cases.
- Created a PowerShell script for automated callback handling from Revit.
2026-03-03 16:13:19 +08:00

517 lines
19 KiB
Python

from __future__ import annotations

import asyncio
import hmac
import logging
import uuid
from datetime import datetime
from typing import Dict, List, Optional, Set

from app.config import settings, software_config
from app.core.cad_task_router import CadTaskRouter, RouteNotFoundError
from app.core.plugin_callback_registry import PluginCallbackRegistry
from app.core.plugin_http_client import PluginHttpClient
from app.core.serial_batch_executor import SerialBatchExecutor
from app.models.cad_batch import (
    BatchItem,
    BatchItemStatus,
    BatchJob,
    BatchStatus,
    BatchSubmitRequest,
    PluginCallbackPayload,
)
from app.models.operation_log import OperationStatus
logger = logging.getLogger(__name__)
# Batch states that admit no further transitions; the transition into one of
# these triggers the one-shot "batch_completed" websocket event in _update_item.
TERMINAL_BATCH_STATUSES: Set[BatchStatus] = {
    BatchStatus.COMPLETED,
    BatchStatus.COMPLETED_WITH_ERRORS,
    BatchStatus.FAILED,
}
class CadBatchManager:
    """In-memory manager for CAD batch jobs.

    Creates batches, tracks per-item state in plain dicts guarded by a single
    asyncio lock, dispatches work through a serial executor, and fans out
    state changes to an (optionally injected) websocket manager and log manager.
    """

    def __init__(
        self,
        task_router: Optional[CadTaskRouter] = None,
        plugin_client: Optional[PluginHttpClient] = None,
        callback_registry: Optional[PluginCallbackRegistry] = None,
    ):
        # Collaborators are injectable for testing; defaults are real instances.
        self._router = task_router or CadTaskRouter()
        self._plugin_client = plugin_client or PluginHttpClient()
        self._callback_registry = callback_registry or PluginCallbackRegistry()
        # The executor holds a back-reference to this manager (first argument).
        self._executor = SerialBatchExecutor(self, self._plugin_client, self._callback_registry)
        self._batches: Dict[str, BatchJob] = {}
        self._items: Dict[str, BatchItem] = {}
        # Injected later via set_websocket_manager / set_log_manager.
        self._websocket_manager = None
        self._log_manager = None
        # One lock guards both dicts and all batch recalculation.
        self._lock = asyncio.Lock()
    def set_websocket_manager(self, websocket_manager):
        """Inject the websocket manager used to broadcast batch events."""
        self._websocket_manager = websocket_manager
    def set_log_manager(self, log_manager):
        """Inject the log manager used to persist system operation records."""
        self._log_manager = log_manager
    async def start(self):
        """Start the underlying serial batch executor."""
        await self._executor.start()
    async def stop(self):
        """Stop the executor, then clear any pending callback registrations."""
        await self._executor.stop()
        await self._callback_registry.clear()
    async def create_batch(self, request: BatchSubmitRequest, submitter_id: Optional[str] = None) -> BatchJob:
        """Create a batch from a submit request and enqueue its routable items.

        Items whose model path cannot be routed to a plugin fail immediately
        (status FAILED, finished_at stamped); the rest are queued on the serial
        executor. Returns a deep snapshot of the batch taken under the lock.
        """
        batch_id = str(uuid.uuid4())
        batch = BatchJob(
            id=batch_id,
            name=request.batch_name,
            submitter_id=submitter_id,
            metadata=request.metadata,
        )
        enqueue_items: List[str] = []
        async with self._lock:
            self._batches[batch_id] = batch
            for idx, submit_item in enumerate(request.items):
                item_id = str(uuid.uuid4())
                status = BatchItemStatus.QUEUED
                file_extension = None
                software_id = None
                max_retries = 0
                error_message = None
                try:
                    file_extension, software_id = self._router.resolve(submit_item.model_path)
                    plugin_cfg = software_config.get_plugin_config(software_id) or {}
                    max_retries = int(plugin_cfg.get("max_retries", 0))
                except RouteNotFoundError as exc:
                    # Unroutable files fail fast instead of occupying the queue.
                    status = BatchItemStatus.FAILED
                    error_message = str(exc)
                item = BatchItem(
                    id=item_id,
                    batch_id=batch_id,
                    sequence=idx,
                    model_path=submit_item.model_path,
                    file_extension=file_extension,
                    software_id=software_id,
                    task_type=submit_item.task_type,
                    task_params=submit_item.task_params,
                    status=status,
                    max_retries=max_retries,
                    error_message=error_message,
                    finished_at=datetime.now() if status == BatchItemStatus.FAILED else None,
                )
                self._items[item_id] = item
                batch.item_ids.append(item_id)
                if status == BatchItemStatus.QUEUED:
                    enqueue_items.append(item_id)
            self._recalculate_batch(batch_id)
            # Snapshot while still holding the lock so the result is consistent.
            batch_snapshot = self._batches[batch_id].model_copy(deep=True)
        # Enqueue after releasing the lock to avoid holding it across awaits
        # into the executor.
        for item_id in enqueue_items:
            await self._executor.enqueue(item_id)
        await self._notify_websocket(
            "batch_created",
            {
                "batch": batch_snapshot.model_dump(mode="json"),
                "items": [i.model_dump(mode="json") for i in await self.get_batch_items(batch_id)],
            },
        )
        await self._log_system_operation(
            operation="batch_create",
            details=f"Created batch {batch_id} with {batch_snapshot.total_count} items",
            status=OperationStatus.SUCCESS,
            extra_data={
                "batch_id": batch_id,
                "submitter_id": submitter_id,
                "total_count": batch_snapshot.total_count,
            },
        )
        return batch_snapshot
async def get_batch(self, batch_id: str) -> Optional[BatchJob]:
async with self._lock:
batch = self._batches.get(batch_id)
return batch.model_copy(deep=True) if batch else None
async def get_item(self, item_id: str) -> Optional[BatchItem]:
async with self._lock:
item = self._items.get(item_id)
return item.model_copy(deep=True) if item else None
async def get_batch_items(self, batch_id: str) -> List[BatchItem]:
async with self._lock:
batch = self._batches.get(batch_id)
if not batch:
return []
return [self._items[item_id].model_copy(deep=True) for item_id in batch.item_ids if item_id in self._items]
def get_callback_endpoint_url(self) -> str:
base = getattr(settings, "batch_callback_base_url", None) or f"http://localhost:{settings.port}"
return f"{base.rstrip('/')}/api/v1/plugin-callbacks/task-result"
def get_callback_timeout_sec(self, software_id: str) -> int:
plugin_cfg = software_config.get_plugin_config(software_id) or {}
return int(plugin_cfg.get("callback_timeout_sec", 60))
def get_retry_backoff_sec(self, software_id: str, attempt_idx: int) -> int:
plugin_cfg = software_config.get_plugin_config(software_id) or {}
backoff = plugin_cfg.get("retry_backoff_sec", [1, 3, 5])
if not isinstance(backoff, list) or not backoff:
return 0
idx = min(attempt_idx, len(backoff) - 1)
try:
return int(backoff[idx])
except (TypeError, ValueError):
return 0
def get_task_completion_mode(self, software_id: str, task_type: str) -> str:
task_cfg = software_config.get_plugin_task_config(software_id, task_type) or {}
mode = task_cfg.get("completion_mode", "callback")
if mode not in {"callback", "sync", "submit_only"}:
return "callback"
return mode
def should_auto_open_model(self, software_id: str, task_type: str) -> bool:
plugin_cfg = software_config.get_plugin_config(software_id) or {}
enabled = bool(plugin_cfg.get("auto_open_model_before_tasks", False))
if not enabled:
return False
exclude = plugin_cfg.get("auto_open_exclude_task_types", ["open_model", "close_model"])
if isinstance(exclude, list) and task_type in exclude:
return False
include = plugin_cfg.get("auto_open_include_task_types")
if isinstance(include, list) and include:
return task_type in include
return True
def should_auto_close_model(self, software_id: str, task_type: str) -> bool:
plugin_cfg = software_config.get_plugin_config(software_id) or {}
enabled = bool(plugin_cfg.get("auto_close_model_after_tasks", False))
if not enabled:
return False
exclude = plugin_cfg.get("auto_close_exclude_task_types", ["open_model", "close_model"])
if isinstance(exclude, list) and task_type in exclude:
return False
include = plugin_cfg.get("auto_close_include_task_types")
if isinstance(include, list) and include:
return task_type in include
return True
def get_inter_step_delay_sec(self, software_id: str) -> float:
plugin_cfg = software_config.get_plugin_config(software_id) or {}
value = plugin_cfg.get("inter_step_delay_sec", 0)
try:
delay = float(value)
except (TypeError, ValueError):
return 0
return max(0.0, delay)
def get_between_items_delay_sec(self, software_id: str) -> float:
plugin_cfg = software_config.get_plugin_config(software_id) or {}
value = plugin_cfg.get("between_items_delay_sec", 0)
try:
delay = float(value)
except (TypeError, ValueError):
return 0
return max(0.0, delay)
def should_run_pre_batch_cleanup(self, software_id: str) -> bool:
plugin_cfg = software_config.get_plugin_config(software_id) or {}
return bool(plugin_cfg.get("pre_batch_cleanup_enabled", False))
def get_pre_batch_cleanup_task_type(self, software_id: str) -> Optional[str]:
plugin_cfg = software_config.get_plugin_config(software_id) or {}
task_type = plugin_cfg.get("pre_batch_cleanup_task_type")
if isinstance(task_type, str) and task_type.strip():
return task_type.strip()
return None
def get_pre_batch_cleanup_task_params(self, software_id: str) -> Dict[str, object]:
plugin_cfg = software_config.get_plugin_config(software_id) or {}
params = plugin_cfg.get("pre_batch_cleanup_task_params")
return params if isinstance(params, dict) else {}
def get_pre_batch_cleanup_ignore_error_markers(self, software_id: str) -> List[str]:
plugin_cfg = software_config.get_plugin_config(software_id) or {}
markers = plugin_cfg.get("pre_batch_cleanup_ignore_error_markers")
if not isinstance(markers, list):
return []
normalized: List[str] = []
for marker in markers:
if isinstance(marker, str):
value = marker.strip()
if value:
normalized.append(value)
return normalized
def get_close_model_ignore_error_markers(self, software_id: str) -> List[str]:
plugin_cfg = software_config.get_plugin_config(software_id) or {}
markers = plugin_cfg.get("close_model_ignore_error_markers")
if not isinstance(markers, list):
return []
normalized: List[str] = []
for marker in markers:
if isinstance(marker, str):
value = marker.strip()
if value:
normalized.append(value)
return normalized
def should_check_open_before_pre_batch_cleanup(self, software_id: str) -> bool:
plugin_cfg = software_config.get_plugin_config(software_id) or {}
return bool(plugin_cfg.get("pre_batch_check_open_enabled", False))
def get_pre_batch_check_task_type(self, software_id: str) -> Optional[str]:
plugin_cfg = software_config.get_plugin_config(software_id) or {}
task_type = plugin_cfg.get("pre_batch_check_task_type")
if isinstance(task_type, str) and task_type.strip():
return task_type.strip()
return None
async def log_execution_event(
self,
operation: str,
details: str,
status: OperationStatus = OperationStatus.SUCCESS,
extra_data: Optional[dict] = None,
):
await self._log_system_operation(
operation=operation,
details=details,
status=status,
extra_data=extra_data or {},
)
def validate_callback_token(self, software_id: str, token: Optional[str]) -> bool:
plugin_cfg = software_config.get_plugin_config(software_id) or {}
expected = plugin_cfg.get("callback_token")
if not expected:
return False
return token == expected
async def handle_plugin_callback(self, payload: PluginCallbackPayload) -> bool:
accepted = await self._callback_registry.handle_callback(payload)
if not accepted:
await self._log_system_operation(
operation="plugin_callback_duplicate",
details=f"Duplicate callback ignored for execution_id={payload.execution_id}",
status=OperationStatus.SUCCESS,
extra_data={"execution_id": payload.execution_id, "software_id": payload.software_id},
)
return accepted
async def mark_item_dispatching(self, item_id: str, execution_id: str, attempt: int):
await self._update_item(
item_id,
{
"status": BatchItemStatus.DISPATCHING,
"execution_id": execution_id,
"retry_count": attempt,
"started_at": datetime.now(),
"error_message": None,
},
)
async def mark_item_waiting_callback(self, item_id: str, execution_id: str, plugin_response: dict):
await self._update_item(
item_id,
{
"status": BatchItemStatus.WAITING_CALLBACK,
"execution_id": execution_id,
"plugin_response": plugin_response or {},
},
)
async def mark_item_retrying(self, item_id: str, error_message: str, next_attempt: int):
await self._update_item(
item_id,
{
"status": BatchItemStatus.QUEUED,
"retry_count": next_attempt,
"error_message": error_message,
},
)
async def mark_item_succeeded(self, item_id: str, result: dict):
await self._update_item(
item_id,
{
"status": BatchItemStatus.SUCCEEDED,
"result": result or {},
"finished_at": datetime.now(),
"error_message": None,
},
)
async def mark_item_failed(self, item_id: str, error_message: str):
await self._update_item(
item_id,
{
"status": BatchItemStatus.FAILED,
"error_message": error_message,
"finished_at": datetime.now(),
},
)
    async def _update_item(self, item_id: str, changes: dict):
        """Apply field changes to an item, refresh its batch, and fan out events.

        Mutation and snapshotting happen under the lock; websocket broadcast
        and logging happen after release, against the snapshots. Unknown item
        or batch ids are ignored silently.
        """
        batch_snapshot = None
        item_snapshot = None
        emit_batch_completed = False
        async with self._lock:
            item = self._items.get(item_id)
            if not item:
                return
            batch = self._batches.get(item.batch_id)
            if not batch:
                return
            prev_batch_status = batch.status
            for key, value in changes.items():
                setattr(item, key, value)
            self._recalculate_batch(item.batch_id)
            item_snapshot = item.model_copy(deep=True)
            batch_snapshot = self._batches[item.batch_id].model_copy(deep=True)
            # Emit "batch_completed" only on the transition INTO a terminal
            # state, so the event fires exactly once per batch.
            emit_batch_completed = (
                prev_batch_status not in TERMINAL_BATCH_STATUSES
                and batch_snapshot.status in TERMINAL_BATCH_STATUSES
            )
        await self._notify_websocket(
            "batch_item_update",
            {
                "item": item_snapshot.model_dump(mode="json"),
                "batch": batch_snapshot.model_dump(mode="json"),
            },
        )
        if item_snapshot.status == BatchItemStatus.FAILED:
            op_status = OperationStatus.FAILED
        else:
            op_status = OperationStatus.SUCCESS
        await self._log_system_operation(
            operation="batch_item_update",
            details=f"Item {item_snapshot.id} status changed to {item_snapshot.status.value}",
            status=op_status,
            extra_data={
                "batch_id": batch_snapshot.id,
                "item_id": item_snapshot.id,
                "status": item_snapshot.status.value,
                "software_id": item_snapshot.software_id,
                "error_message": item_snapshot.error_message,
            },
        )
        if emit_batch_completed:
            await self._notify_websocket(
                "batch_completed",
                {
                    "batch": batch_snapshot.model_dump(mode="json"),
                    "items": [i.model_dump(mode="json") for i in await self.get_batch_items(batch_snapshot.id)],
                },
            )
def _recalculate_batch(self, batch_id: str):
batch = self._batches[batch_id]
items = [self._items[item_id] for item_id in batch.item_ids if item_id in self._items]
total = len(items)
queued = sum(1 for i in items if i.status == BatchItemStatus.QUEUED)
running = sum(
1
for i in items
if i.status in {BatchItemStatus.DISPATCHING, BatchItemStatus.WAITING_CALLBACK}
)
succeeded = sum(1 for i in items if i.status == BatchItemStatus.SUCCEEDED)
failed = sum(1 for i in items if i.status == BatchItemStatus.FAILED)
batch.total_count = total
batch.queued_count = queued
batch.running_count = running
batch.succeeded_count = succeeded
batch.failed_count = failed
if running > 0:
batch.status = BatchStatus.RUNNING
if batch.started_at is None:
batch.started_at = datetime.now()
batch.finished_at = None
return
if queued > 0:
batch.status = BatchStatus.PENDING
return
if total == 0:
batch.status = BatchStatus.FAILED
batch.finished_at = datetime.now()
return
if failed == total:
batch.status = BatchStatus.FAILED
elif failed > 0:
batch.status = BatchStatus.COMPLETED_WITH_ERRORS
else:
batch.status = BatchStatus.COMPLETED
if batch.started_at is None:
batch.started_at = datetime.now()
batch.finished_at = datetime.now()
async def _notify_websocket(self, message_type: str, data: dict):
if self._websocket_manager:
await self._websocket_manager.broadcast(
{
"type": message_type,
"data": data,
"timestamp": datetime.now().isoformat(),
}
)
async def _log_system_operation(self, operation: str, details: str, status=OperationStatus.SUCCESS, **kwargs):
extra_data = kwargs.get("extra_data")
prefix = "[batch]"
if status == OperationStatus.FAILED:
logger.warning("%s %s | %s | extra=%s", prefix, operation, details, extra_data or {})
else:
logger.info("%s %s | %s | extra=%s", prefix, operation, details, extra_data or {})
if self._log_manager:
await self._log_manager.log_system_operation(
operation=operation,
details=details,
status=status,
operation_category="batch_processing",
**kwargs,
)
# Module-level singleton; importing modules share this one manager instance.
cad_batch_manager = CadBatchManager()