bidmaster-cli/src/bidmaster/agents/analysis.py
sladro b5b727a61d refactor: 消除技术目录生成Agent中的重复代码
- 统一条件判断函数,移除analysis.py中的重复别名
- 创建BaseAgentFactory基类,抽取共同的工厂方法
- AnalysisAgent和TocAgent继承工厂基类,移除重复代码
- 增强BaseNode状态更新功能,添加进度和警告辅助方法

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-29 11:05:52 +08:00

511 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Analysis Agent - Phase 1 分析阶段Agent
基于LangGraph实现的招标文件分析Agent负责
1. 解析招标文件中的评分表和偏离表
2. 智能分类技术和商务评分项
3. 生成专业的标书章节结构
"""
import logging
from pathlib import Path
from typing import List, Dict, Any, TypedDict
import asyncio
from langgraph.graph import StateGraph, END
from pydantic import BaseModel, Field
from ..tools.parser import BidParser, BidStructure, ScoringCriteria, DeviationItem, DocumentChapter
from ..config import get_settings
from ..nodes.toc.workflow_utils import should_continue_workflow
from .base import BaseAgentFactory
logger = logging.getLogger(__name__)
def _validate_template_file(file_path: str) -> bool:
    """Check whether *file_path* points to a usable Word template.

    A valid template is an existing regular file whose suffix is ``.docx``
    (compared case-insensitively).

    Args:
        file_path: Path of the candidate template file.

    Returns:
        True if the path is an existing ``.docx`` file, False otherwise,
        including when the value cannot be interpreted as a path or the
        path cannot be inspected.
    """
    try:
        path = Path(file_path)
        # Cheap suffix test first (no filesystem access). is_file() rather than
        # exists(): a directory named "foo.docx" is not a template.
        return path.suffix.lower() == ".docx" and path.is_file()
    except (TypeError, ValueError, OSError):
        # TypeError: not path-like (e.g. None); ValueError/OSError: malformed
        # or inaccessible path. Validation is best-effort, so report "invalid".
        return False
class AnalysisAgentState(TypedDict):
    """State shared by the Analysis Agent's LangGraph nodes.

    ``AnalysisAgent.execute`` populates every key before the graph starts.
    Each node reads what it needs and writes its outputs back into the same
    mapping.
    """

    # --- Inputs ---
    source_file: str  # path of the tender document (.docx)
    interaction_handler: Any  # InteractionHandler instance; if falsy the TOC node creates an interactive one

    # --- Execution status ---
    current_step: str  # name of the last node that completed successfully
    progress: float  # 0.0 - 1.0
    should_continue: bool  # set to False by a node on failure or completion; presumably read by should_continue_workflow

    # --- Intermediate data ---
    raw_tables: List[Dict[str, Any]]  # raw tables extracted from the document
    classified_tables: Dict[str, List[Dict[str, Any]]]  # {"scoring": [...], "deviation": [...], "other": [...]}
    technical_criteria: List[ScoringCriteria]  # technical scoring items
    commercial_criteria: List[ScoringCriteria]  # commercial scoring items
    deviation_items: List[DeviationItem]  # deviation items

    # --- Chapter generation ---
    preliminary_chapters: List[DocumentChapter]  # chapters produced by the TOC step
    structure_review: Dict[str, Any]  # AI review result (the TOC node currently leaves it empty)

    # --- Final output ---
    # None until finalize_structure_node runs; execute() initialises it to None.
    bid_structure: BidStructure | None

    # --- Error handling ---
    error: str  # empty string means "no error"
    warnings: List[str]
class AnalysisResult(BaseModel):
    """Outcome of one Analysis Agent run, as returned by ``AnalysisAgent.execute``.

    When ``success`` is False, ``error_message`` is set and the count fields
    keep their default value of 0.
    """
    success: bool = Field(description="是否执行成功")
    # None when the run failed.
    bid_structure: BidStructure | None = Field(default=None, description="标书结构")
    technical_count: int = Field(default=0, description="技术评分项数量")
    # Commercial items are counted here even though they are not stored in bid_structure.
    commercial_count: int = Field(default=0, description="商务评分项数量")
    deviation_count: int = Field(default=0, description="偏离项数量")
    chapter_count: int = Field(default=0, description="章节数量")
    error_message: str | None = Field(default=None, description="错误信息")
    warnings: List[str] = Field(default_factory=list, description="警告信息")
    # Seconds spent inside execute().
    execution_time: float = Field(default=0.0, description="执行时间(秒)")
# ========== LangGraph 节点函数 ==========
def validate_file_node(state: AnalysisAgentState) -> AnalysisAgentState:
    """Node 1: validate the tender document before any parsing happens.

    Checks that ``state["source_file"]`` is an existing regular file whose
    name ends in ``.docx``. Files larger than 50 MB are accepted but add a
    warning, because parsing them may be slow.

    On success sets ``current_step``, ``progress`` and ``should_continue=True``.
    On failure stores the message in ``state["error"]`` and sets
    ``should_continue=False``. The mutated state is always returned.
    """
    logger.info("开始验证文件...")
    # Soft limit: larger files are still processed, they only trigger a warning.
    max_size = 50 * 1024 * 1024  # 50MB
    try:
        source_file = state["source_file"]
        file_path = Path(source_file)

        if not file_path.exists():
            raise FileNotFoundError(f"文件不存在: {source_file}")
        # exists() is also true for directories; reject anything that is not a
        # regular file here instead of failing obscurely when it is opened later.
        if not file_path.is_file():
            raise ValueError(f"路径不是文件: {source_file}")

        if not source_file.lower().endswith('.docx'):
            raise ValueError(f"不支持的文件格式,只支持.docx格式: {source_file}")

        file_size = file_path.stat().st_size
        if file_size > max_size:
            state["warnings"].append(f"文件较大({file_size/1024/1024:.1f}MB),解析可能较慢")

        logger.info(f"文件验证成功: {source_file}")
        state["current_step"] = "validate_file"
        state["progress"] = 0.1
        state["should_continue"] = True
    except Exception as e:
        logger.error(f"文件验证失败: {e}")
        state["error"] = str(e)
        state["should_continue"] = False
    return state
def extract_tables_node(state: AnalysisAgentState) -> AnalysisAgentState:
    """Node 2: extract candidate tables from the Word document.

    Tables with fewer than two rows (a header plus at least one data row) are
    skipped. For every other table this records its position among all tables
    in the document (``index``), its row and column counts, and its text as
    returned by ``BidParser.extract_table_text``. The result is stored in
    ``state["raw_tables"]``.

    Fails (``error`` set, ``should_continue=False``) if the document cannot
    be opened or contains no usable table.
    """
    logger.info("开始提取表格...")
    try:
        # Function-scope import: python-docx is only loaded when this node runs.
        from docx import Document

        doc = Document(state["source_file"])
        # One parser for the whole document, as the other nodes do, instead of
        # constructing (and re-importing) a new one for every table.
        parser = BidParser()
        raw_tables = []
        for i, table in enumerate(doc.tables):
            rows = table.rows
            if len(rows) < 2:  # need at least a header row and one data row
                continue
            raw_tables.append({
                "index": i,
                "row_count": len(rows),
                # rows is non-empty here, so max() always has input.
                "col_count": max(len(row.cells) for row in rows),
                "text_content": parser.extract_table_text(table),
            })

        if not raw_tables:
            raise ValueError("文档中未找到有效的表格")

        logger.info(f"成功提取{len(raw_tables)}个表格")
        state["raw_tables"] = raw_tables
        state["current_step"] = "extract_tables"
        state["progress"] = 0.25
    except Exception as e:
        logger.error(f"表格提取失败: {e}")
        state["error"] = str(e)
        state["should_continue"] = False
    return state
def classify_tables_node(state: AnalysisAgentState) -> AnalysisAgentState:
    """Node 3: sort the extracted tables into scoring, deviation and other tables.

    Each raw table's text is passed to ``BidParser._identify_table_type``. The
    table is filed under the returned type when that is ``"scoring"`` or
    ``"deviation"``, and under ``"other"`` otherwise; the raw type is kept in
    the entry's ``"type"`` field. A warning is added when no scoring table is
    found. Errors are recorded in ``state["error"]`` and stop the pipeline.
    """
    logger.info("开始分类表格...")
    try:
        bid_parser = BidParser()
        buckets = {"scoring": [], "deviation": [], "other": []}

        for raw in state["raw_tables"]:
            text = raw["text_content"]
            kind = bid_parser._identify_table_type(text)
            # Unknown labels fall back to the "other" bucket.
            target = buckets[kind] if kind in buckets else buckets["other"]
            target.append({
                "index": raw["index"],
                "type": kind,
                "text_content": text,
                "row_count": raw["row_count"],
                "col_count": raw["col_count"],
            })

        n_scoring = len(buckets["scoring"])
        n_deviation = len(buckets["deviation"])
        logger.info(f"表格分类完成: 评分表{n_scoring}个, 偏离表{n_deviation}")
        if not n_scoring:
            state["warnings"].append("未识别到评分表,可能影响结果质量")

        state["classified_tables"] = buckets
        state["current_step"] = "classify_tables"
        state["progress"] = 0.4
    except Exception as err:
        logger.error(f"表格分类失败: {err}")
        state["error"] = str(err)
        state["should_continue"] = False
    return state
def parse_content_node(state: AnalysisAgentState) -> AnalysisAgentState:
    """Node 4: parse the classified tables into scoring and deviation items.

    Every scoring table is parsed with ``BidParser._ai_parse_scoring_table``.
    Each resulting criterion gets an ``original_index`` that counts across
    all scoring tables, so the document order survives the split. Criteria
    whose category value is ``"commercial"`` go to ``commercial_criteria``;
    all others go to ``technical_criteria``. Deviation tables are parsed with
    ``BidParser._ai_parse_deviation_table``. Errors are recorded in
    ``state["error"]`` and stop the pipeline.
    """
    logger.info("开始解析表格内容...")
    try:
        bid_parser = BidParser()
        tables_by_type = state["classified_tables"]
        technical, commercial, deviations = [], [], []

        position = 0  # running index across all scoring tables
        for scoring_table in tables_by_type.get("scoring", []):
            parsed_criteria = bid_parser._ai_parse_scoring_table(scoring_table["text_content"])
            for criterion in parsed_criteria or []:
                criterion.original_index = position
                position += 1
                bucket = commercial if criterion.category.value == "commercial" else technical
                bucket.append(criterion)

        for deviation_table in tables_by_type.get("deviation", []):
            parsed_deviations = bid_parser._ai_parse_deviation_table(deviation_table["text_content"])
            if parsed_deviations:
                deviations.extend(parsed_deviations)

        logger.info(f"内容解析完成: 技术项{len(technical)}个, 商务项{len(commercial)}个, 偏离项{len(deviations)}")
        state["technical_criteria"] = technical
        state["commercial_criteria"] = commercial
        state["deviation_items"] = deviations
        state["current_step"] = "parse_content"
        state["progress"] = 0.7
    except Exception as err:
        logger.error(f"内容解析失败: {err}")
        state["error"] = str(err)
        state["should_continue"] = False
    return state
def generate_toc_with_agent_node(state: AnalysisAgentState) -> AnalysisAgentState:
    """Node 5: generate the chapter structure with ``TocAgent``.

    Calls ``TocAgent.generate_toc_sync`` in AI mode on the technical scoring
    criteria, then links each criterion to a generated chapter. A chapter id
    with at least three ``_``-separated parts encodes a category in
    everything after the second ``_``. A criterion is linked to the first
    chapter, in generation order, whose encoded category equals
    ``criterion.category.value``. Criteria without a match keep their
    current ``chapter_id``.

    Fails (``error`` set, ``should_continue=False``) when there are no
    technical criteria or when TocAgent reports an error. Warnings returned
    by TocAgent are appended to ``state["warnings"]``.
    """
    logger.info("开始使用TocAgent生成目录...")
    try:
        technical_criteria = state["technical_criteria"]
        if not technical_criteria:
            raise ValueError("缺少技术评分项,无法生成章节结构")

        # Use the caller-supplied InteractionHandler; otherwise create an interactive one.
        interaction_handler = state.get("interaction_handler")
        if not interaction_handler:
            from .interaction import InteractionHandler, InteractionMode
            interaction_handler = InteractionHandler(mode=InteractionMode.INTERACTIVE)

        from .builders.toc_builder import TocAgent
        toc_agent = TocAgent(interaction_handler)
        # Template-based generation is not used here: always run in AI mode.
        result = toc_agent.generate_toc_sync(
            technical_criteria=technical_criteria,
            generation_mode="ai",
            template_file=None,
        )
        if result.get("error"):
            raise ValueError(f"目录生成失败: {result['error']}")

        chapters = result.get("final_chapters", [])

        # Index chapters once by the category encoded in their id
        # ("<a>_<b>_<category>", where <category> may itself contain "_").
        # setdefault keeps the first chapter in generation order per category.
        chapter_id_by_category = {}
        for chapter in chapters:
            parts = chapter.id.split("_")
            if len(parts) >= 3:
                chapter_id_by_category.setdefault("_".join(parts[2:]), chapter.id)

        for criterion in technical_criteria:
            chapter_id = chapter_id_by_category.get(criterion.category.value)
            if chapter_id is not None:
                criterion.chapter_id = chapter_id

        logger.info(f"目录生成完成: {len(chapters)}个章节")
        state["preliminary_chapters"] = chapters
        state["technical_criteria"] = technical_criteria
        state["structure_review"] = {}  # TocAgent already includes its own review
        state["current_step"] = "generate_toc"
        state["progress"] = 0.85
        state["warnings"].extend(result.get("warnings") or [])
    except Exception as e:
        logger.error(f"目录生成失败: {e}")
        state["error"] = str(e)
        state["should_continue"] = False
    return state
def finalize_structure_node(state: AnalysisAgentState) -> AnalysisAgentState:
    """Node 6: assemble the final ``BidStructure`` and end the workflow.

    Builds a ``BidStructure`` from the technical criteria, the deviation
    items and the chapters produced by the TOC step. The project is named
    after the stem of the source file. A non-empty ``structure_review`` is
    attached to the result.

    On success ``progress`` becomes 1.0 and ``should_continue`` is set to
    False, since this is the last step. On failure the message is stored in
    ``state["error"]``.
    """
    logger.info("开始最终确定标书结构...")
    try:
        criteria = state["technical_criteria"]
        chapters = state["preliminary_chapters"]
        review = state.get("structure_review", {})
        source_file = state["source_file"]

        # NOTE(review): only technical criteria are stored in scoring_criteria;
        # commercial criteria are counted by AnalysisAgent.execute but not kept
        # in the structure — confirm this is intentional.
        structure = BidStructure(
            project_name=f"标书项目-{Path(source_file).stem}",
            scoring_criteria=criteria,
            deviation_items=state["deviation_items"],
            chapters=chapters,
            scoring_file=source_file,
        )
        if review:
            structure.structure_review = review

        logger.info(f"标书结构确定完成: {len(chapters)}个章节")
        state.update(
            bid_structure=structure,
            current_step="finalize_structure",
            progress=1.0,
            should_continue=False,  # last node: nothing left to run
        )
    except Exception as err:
        logger.error(f"标书结构确定失败: {err}")
        state["error"] = str(err)
        state["should_continue"] = False
    return state
# ========== Conditional-edge routing ==========
# No local routing functions: the graph uses the shared should_continue_workflow helper directly.
class AnalysisAgent(BaseAgentFactory):
    """Phase-1 analysis agent that turns a tender document into a ``BidStructure``.

    Runs a linear LangGraph pipeline:
    validate_file -> extract_tables -> classify_tables -> parse_content
    -> generate_toc -> finalize_structure.
    After every step except the last, ``should_continue_workflow`` chooses
    between moving on ("continue") and stopping ("end").
    """

    def __init__(self, interaction_handler=None):
        """Create the agent and compile its workflow graph.

        Args:
            interaction_handler: Optional InteractionHandler forwarded to the
                TOC generation step. When it is falsy, that step creates an
                interactive handler itself.
        """
        self.settings = get_settings()
        self.interaction_handler = interaction_handler
        self.graph = self._build_graph()

    def _build_graph(self):
        """Build and compile the LangGraph workflow.

        Returns:
            The compiled graph produced by ``StateGraph.compile()``; it is
            used through ``ainvoke``.
        """
        # (node name, node function) in execution order.
        pipeline = [
            ("validate_file", validate_file_node),
            ("extract_tables", extract_tables_node),
            ("classify_tables", classify_tables_node),
            ("parse_content", parse_content_node),
            ("generate_toc", generate_toc_with_agent_node),
            ("finalize_structure", finalize_structure_node),
        ]

        workflow = StateGraph(AnalysisAgentState)
        for name, node in pipeline:
            workflow.add_node(name, node)
        workflow.set_entry_point(pipeline[0][0])

        # Every step except the last may stop the run early.
        for (current, _), (following, _) in zip(pipeline, pipeline[1:]):
            workflow.add_conditional_edges(
                current,
                should_continue_workflow,
                {
                    "continue": following,
                    "end": END,
                },
            )
        workflow.add_edge(pipeline[-1][0], END)
        return workflow.compile()

    async def execute(self, source_file: str, progress_callback=None) -> AnalysisResult:
        """Run the workflow on *source_file* and summarise the outcome.

        Args:
            source_file: Path of the tender document (.docx).
            progress_callback: Accepted for interface compatibility; it is
                currently not used.

        Returns:
            An AnalysisResult. A failure recorded by a node (``error`` set in
            the final state) and an exception raised while running the graph
            are both reported as ``success=False`` instead of being raised.
        """
        import time
        # Monotonic clock: durations must not be affected by wall-clock adjustments.
        start_time = time.perf_counter()
        logger.info(f"开始执行Analysis Agent: {source_file}")

        # Every state key is initialised here so the nodes can read them unconditionally.
        initial_state = AnalysisAgentState(
            source_file=source_file,
            interaction_handler=self.interaction_handler,
            current_step="",
            progress=0.0,
            should_continue=True,
            raw_tables=[],
            classified_tables={},
            technical_criteria=[],
            commercial_criteria=[],
            deviation_items=[],
            preliminary_chapters=[],
            structure_review={},
            bid_structure=None,
            error="",
            warnings=[]
        )
        try:
            final_state = await self.graph.ainvoke(initial_state)

            if final_state.get("error"):
                result = AnalysisResult(
                    success=False,
                    error_message=final_state["error"],
                    warnings=final_state.get("warnings", []),
                    execution_time=time.perf_counter() - start_time
                )
            else:
                bid_structure = final_state["bid_structure"]
                result = AnalysisResult(
                    success=True,
                    bid_structure=bid_structure,
                    technical_count=len(final_state.get("technical_criteria", [])),
                    commercial_count=len(final_state.get("commercial_criteria", [])),
                    deviation_count=len(final_state.get("deviation_items", [])),
                    chapter_count=len(bid_structure.chapters) if bid_structure else 0,
                    warnings=final_state.get("warnings", []),
                    execution_time=time.perf_counter() - start_time
                )
            logger.info(f"Analysis Agent执行完成耗时{result.execution_time:.2f}")
            return result
        except Exception as e:
            # logger.exception keeps the traceback of unexpected graph failures.
            logger.exception(f"Analysis Agent执行异常: {e}")
            return AnalysisResult(
                success=False,
                error_message=str(e),
                execution_time=time.perf_counter() - start_time
            )

    def execute_sync(self, source_file: str) -> AnalysisResult:
        """Synchronous entry point for the CLI.

        Uses ``asyncio.run``, so it must not be called from code that is
        already running inside an event loop.
        """
        return asyncio.run(self.execute(source_file))