diff --git a/src/bidmaster/agents/analysis.py b/src/bidmaster/agents/analysis.py index d7b8ee7..96de6ad 100644 --- a/src/bidmaster/agents/analysis.py +++ b/src/bidmaster/agents/analysis.py @@ -17,6 +17,7 @@ from pydantic import BaseModel, Field from ..tools.parser import BidParser, BidStructure, ScoringCriteria, DeviationItem, DocumentChapter from ..config import get_settings from ..nodes.toc.workflow_utils import should_continue_workflow +from .base import BaseAgentFactory logger = logging.getLogger(__name__) @@ -364,11 +365,10 @@ def finalize_structure_node(state: AnalysisAgentState) -> AnalysisAgentState: # ========== 条件判断函数 ========== -# 使用通用的工作流条件判断函数 -should_continue_processing = should_continue_workflow +# 直接使用通用的工作流条件判断函数 -class AnalysisAgent: +class AnalysisAgent(BaseAgentFactory): """Analysis Agent - 第一阶段分析Agent""" def __init__(self, interaction_handler=None): @@ -394,7 +394,7 @@ class AnalysisAgent: # 添加条件边 workflow.add_conditional_edges( "validate_file", - should_continue_processing, + should_continue_workflow, { "continue": "extract_tables", "end": END @@ -403,7 +403,7 @@ class AnalysisAgent: workflow.add_conditional_edges( "extract_tables", - should_continue_processing, + should_continue_workflow, { "continue": "classify_tables", "end": END @@ -412,7 +412,7 @@ class AnalysisAgent: workflow.add_conditional_edges( "classify_tables", - should_continue_processing, + should_continue_workflow, { "continue": "parse_content", "end": END @@ -421,7 +421,7 @@ class AnalysisAgent: workflow.add_conditional_edges( "parse_content", - should_continue_processing, + should_continue_workflow, { "continue": "generate_toc", "end": END @@ -430,7 +430,7 @@ class AnalysisAgent: workflow.add_conditional_edges( "generate_toc", - should_continue_processing, + should_continue_workflow, { "continue": "finalize_structure", "end": END @@ -507,41 +507,4 @@ class AnalysisAgent: """同步执行接口(用于CLI调用)""" return asyncio.run(self.execute(source_file)) - @classmethod - def create_with_handler(cls, interaction_handler) -> 'AnalysisAgent': - """使用指定的交互处理器创建Agent - - Args: - interaction_handler: 交互处理函数 - - Returns: - AnalysisAgent实例 - """ - return cls(interaction_handler) - - @classmethod - def create_silent(cls) -> 'AnalysisAgent': - """创建静默模式的Agent(使用默认值) - - Returns: - 静默模式的AnalysisAgent实例 - """ - from .interaction import InteractionHandler, InteractionMode - handler = InteractionHandler(mode=InteractionMode.SILENT) - return cls(handler) - - @classmethod - def create_programmatic(cls, presets: Dict[str, Any]) -> 'AnalysisAgent': - """创建程序化模式的Agent(使用预设值) - - Args: - presets: 预设值字典 - - Returns: - 程序化模式的AnalysisAgent实例 - """ - from .interaction import InteractionHandler, InteractionMode - handler = InteractionHandler(mode=InteractionMode.PROGRAMMATIC, presets=presets) - return cls(handler) - diff --git a/src/bidmaster/agents/base.py b/src/bidmaster/agents/base.py index e45d05f..0f91597 100644 --- a/src/bidmaster/agents/base.py +++ b/src/bidmaster/agents/base.py @@ -273,4 +273,48 @@ class BaseAgent: Returns: 错误信息列表 """ - return self.builder.validate_configuration() \ No newline at end of file + return self.builder.validate_configuration() + + +class BaseAgentFactory: + """Agent工厂基类 + + 提供通用的Agent创建方法,避免重复代码。 + """ + + @classmethod + def create_with_handler(cls, interaction_handler): + """使用指定的交互处理器创建Agent + + Args: + interaction_handler: InteractionHandler实例或兼容的交互处理函数 + + Returns: + Agent实例 + """ + return cls(interaction_handler) + + @classmethod + def create_silent(cls): + """创建静默模式的Agent(使用默认值) + + Returns: + 静默模式的Agent实例 + """ + from .interaction import InteractionHandler, InteractionMode + handler = InteractionHandler(mode=InteractionMode.SILENT) + return cls(handler) + + @classmethod + def create_programmatic(cls, presets: Dict[str, Any]): + """创建程序化模式的Agent(使用预设值) + + Args: + presets: 预设值字典 + + Returns: + 程序化模式的Agent实例 + """ + from .interaction import InteractionHandler, InteractionMode + handler = InteractionHandler(mode=InteractionMode.PROGRAMMATIC, presets=presets) + return cls(handler) \ No newline at end of file diff --git a/src/bidmaster/agents/builders/toc_builder.py b/src/bidmaster/agents/builders/toc_builder.py index b489c88..975dd73 100644 --- a/src/bidmaster/agents/builders/toc_builder.py +++ b/src/bidmaster/agents/builders/toc_builder.py @@ -8,7 +8,7 @@ from typing import Dict, Any, Optional, Callable from langgraph.graph import END -from ..base import AgentBuilder, BaseAgent +from ..base import AgentBuilder, BaseAgent, BaseAgentFactory from ...nodes.toc import ( GroupCriteriaNode, GenerateFirstLevelNode, @@ -87,7 +87,7 @@ class TocAgentBuilder(AgentBuilder): self.add_edge("finalize_chapters", "END") -class TocAgent(BaseAgent): +class TocAgent(BaseAgent, BaseAgentFactory): """目录生成Agent 封装了目录生成的完整工作流程。 @@ -102,43 +102,6 @@ class TocAgent(BaseAgent): builder = TocAgentBuilder.create(interaction_handler) super().__init__(builder) - @classmethod - def create_with_handler(cls, interaction_handler) -> 'TocAgent': - """使用指定的交互处理器创建Agent - - Args: - interaction_handler: InteractionHandler实例 - - Returns: - TocAgent实例 - """ - return cls(interaction_handler) - - @classmethod - def create_silent(cls) -> 'TocAgent': - """创建静默模式的Agent(使用默认值) - - Returns: - 静默模式的TocAgent实例 - """ - from ..interaction import InteractionHandler, InteractionMode - handler = InteractionHandler(mode=InteractionMode.SILENT) - return cls(handler) - - @classmethod - def create_programmatic(cls, presets: Dict[str, Any]) -> 'TocAgent': - """创建程序化模式的Agent(使用预设值) - - Args: - presets: 预设值字典 - - Returns: - 程序化模式的TocAgent实例 - """ - from ..interaction import InteractionHandler, InteractionMode - handler = InteractionHandler(mode=InteractionMode.PROGRAMMATIC, presets=presets) - return cls(handler) - async def generate_toc(self, technical_criteria: list, generation_mode: Optional[str] = None, diff --git a/src/bidmaster/nodes/base.py b/src/bidmaster/nodes/base.py index a2fe9c6..111edd1 100644 --- a/src/bidmaster/nodes/base.py +++ b/src/bidmaster/nodes/base.py @@ -116,6 +116,31 @@ class BaseNode(ABC): state.update(kwargs) return state + def _update_progress(self, state: Dict[str, Any], progress: float, **kwargs) -> Dict[str, Any]: + """更新进度状态 + + Args: + state: 当前状态字典 + progress: 进度值 (0.0-1.0) + **kwargs: 其他要更新的状态 + + Returns: + 更新后的状态字典 + """ + kwargs["progress"] = progress + return self._update_state(state, **kwargs) + + def _add_warning(self, state: Dict[str, Any], warning_message: str) -> None: + """添加警告信息 + + Args: + state: 当前状态字典 + warning_message: 警告信息 + """ + warnings = state.setdefault("warnings", []) + warnings.append(warning_message) + logger.warning(f"节点 {self.name}: {warning_message}") + def _handle_execution_error(self, state: Dict[str, Any], error: Exception) -> Dict[str, Any]: """统一的执行错误处理