- 统一条件判断函数,移除analysis.py中的重复别名 - 创建BaseAgentFactory基类,抽取共同的工厂方法 - AnalysisAgent和TocAgent继承工厂基类,移除重复代码 - 增强BaseNode状态更新功能,添加进度和警告辅助方法 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
511 lines
17 KiB
Python
"""Analysis Agent - Phase 1 分析阶段Agent
|
||
|
||
基于LangGraph实现的招标文件分析Agent,负责:
|
||
1. 解析招标文件中的评分表和偏离表
|
||
2. 智能分类技术和商务评分项
|
||
3. 生成专业的标书章节结构
|
||
"""
|
||
|
||
import logging
|
||
from pathlib import Path
|
||
from typing import List, Dict, Any, TypedDict
|
||
import asyncio
|
||
|
||
from langgraph.graph import StateGraph, END
|
||
from pydantic import BaseModel, Field
|
||
|
||
from ..tools.parser import BidParser, BidStructure, ScoringCriteria, DeviationItem, DocumentChapter
|
||
from ..config import get_settings
|
||
from ..nodes.toc.workflow_utils import should_continue_workflow
|
||
from .base import BaseAgentFactory
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _validate_template_file(file_path: str) -> bool:
|
||
"""验证模板文件是否有效
|
||
|
||
Args:
|
||
file_path: 文件路径
|
||
|
||
Returns:
|
||
是否为有效的Word模板文件
|
||
"""
|
||
try:
|
||
path = Path(file_path)
|
||
return path.exists() and path.suffix.lower() == '.docx'
|
||
except Exception:
|
||
return False
|
||
|
||
|
||
class AnalysisAgentState(TypedDict):
    """Shared state flowing through every Analysis Agent workflow node.

    Each LangGraph node receives this mapping, mutates the keys it owns
    and returns it.  ``error`` plus ``should_continue`` drive the
    conditional edges that can abort the pipeline early.
    """

    # --- Inputs ---
    source_file: str  # path to the tender .docx document being analysed
    interaction_handler: Any  # InteractionHandler instance; nodes create a default when absent

    # --- Execution status ---
    current_step: str  # name of the most recently completed node
    progress: float  # overall progress, 0.0 - 1.0
    should_continue: bool  # False routes the workflow to END at the next conditional edge

    # --- Intermediate data ---
    raw_tables: List[Dict[str, Any]]  # raw table payloads extracted from the document
    classified_tables: Dict[str, List[Dict[str, Any]]]  # classified buckets {"scoring": [], "deviation": [], ...}
    technical_criteria: List[ScoringCriteria]  # technical scoring items
    commercial_criteria: List[ScoringCriteria]  # commercial scoring items
    deviation_items: List[DeviationItem]  # items parsed from deviation tables

    # --- TOC generation ---
    preliminary_chapters: List[DocumentChapter]  # chapters produced by TocAgent
    structure_review: Dict[str, Any]  # AI structure-review result (empty dict when none)

    # --- Final output ---
    bid_structure: BidStructure  # assembled result; None until finalize_structure runs

    # --- Error handling ---
    error: str  # non-empty string marks a failed run
    warnings: List[str]  # accumulated non-fatal warnings
|
||
|
||
|
||
class AnalysisResult(BaseModel):
    """Execution result of the Analysis Agent.

    Summarises one run: success flag, the assembled ``BidStructure``,
    per-category item counts, warnings collected along the way and the
    wall-clock execution time.  The ``description=`` strings below are
    runtime pydantic metadata and are intentionally left unchanged.
    """

    success: bool = Field(description="是否执行成功")
    bid_structure: BidStructure | None = Field(default=None, description="标书结构")
    technical_count: int = Field(default=0, description="技术评分项数量")
    commercial_count: int = Field(default=0, description="商务评分项数量")
    deviation_count: int = Field(default=0, description="偏离项数量")
    chapter_count: int = Field(default=0, description="章节数量")
    error_message: str | None = Field(default=None, description="错误信息")
    warnings: List[str] = Field(default_factory=list, description="警告信息")
    execution_time: float = Field(default=0.0, description="执行时间(秒)")
|
||
|
||
|
||
# ========== LangGraph 节点函数 ==========
|
||
|
||
def validate_file_node(state: AnalysisAgentState) -> AnalysisAgentState:
    """Node 1: validate the tender document.

    Checks existence and ``.docx`` format; files over 50MB only raise a
    warning.  On failure the exception message is stored in
    ``state["error"]`` and ``should_continue`` is cleared.
    """
    logger.info("开始验证文件...")

    try:
        source_file = state["source_file"]
        path = Path(source_file)

        # The file must exist and be a Word .docx document.
        if not path.exists():
            raise FileNotFoundError(f"文件不存在: {source_file}")
        if not source_file.lower().endswith('.docx'):
            raise ValueError(f"不支持的文件格式,只支持.docx格式: {source_file}")

        # Large files (>50MB) are allowed but flagged as potentially slow.
        file_size = path.stat().st_size
        if file_size > 50 * 1024 * 1024:
            state["warnings"].append(f"文件较大({file_size/1024/1024:.1f}MB),解析可能较慢")

        logger.info(f"文件验证成功: {source_file}")
        state["current_step"] = "validate_file"
        state["progress"] = 0.1
        state["should_continue"] = True

    except Exception as e:
        logger.error(f"文件验证失败: {e}")
        state["error"] = str(e)
        state["should_continue"] = False

    return state
|
||
|
||
|
||
def extract_tables_node(state: AnalysisAgentState) -> AnalysisAgentState:
    """Node 2: extract all tables from the Word document.

    Opens the ``.docx`` with python-docx, skips tables with fewer than
    two rows (need a header plus at least one data row) and stores the
    remaining tables' text payloads in ``state["raw_tables"]``.

    Fix: the original re-imported ``BidParser`` and constructed a new
    parser instance on every loop iteration — the import already exists
    at module level and the parser is loop-invariant, so both are
    hoisted out of the loop.
    """
    logger.info("开始提取表格...")

    try:
        # Lazy import keeps python-docx off the import path until needed.
        from docx import Document

        source_file = state["source_file"]
        doc = Document(source_file)

        # One parser serves every table (hoisted out of the loop).
        parser = BidParser()

        raw_tables = []
        for i, table in enumerate(doc.tables):
            if len(table.rows) < 2:  # need a header row and at least one data row
                continue

            raw_tables.append({
                "index": i,
                "row_count": len(table.rows),
                "col_count": max(len(row.cells) for row in table.rows) if table.rows else 0,
                "text_content": parser.extract_table_text(table),
            })

        if not raw_tables:
            raise ValueError("文档中未找到有效的表格")

        logger.info(f"成功提取{len(raw_tables)}个表格")
        state["raw_tables"] = raw_tables
        state["current_step"] = "extract_tables"
        state["progress"] = 0.25

    except Exception as e:
        logger.error(f"表格提取失败: {e}")
        state["error"] = str(e)
        state["should_continue"] = False

    return state
|
||
|
||
|
||
def classify_tables_node(state: AnalysisAgentState) -> AnalysisAgentState:
    """Node 3: classify extracted tables by type.

    Uses BidParser's existing AI-based recogniser to label each raw
    table, then buckets them into "scoring", "deviation" or the
    catch-all "other".  Emits a warning when no scoring table is found.
    """
    logger.info("开始分类表格...")

    try:
        parser = BidParser()
        classified_tables = {"scoring": [], "deviation": [], "other": []}

        for raw_entry in state["raw_tables"]:
            text = raw_entry["text_content"]

            # Reuse the parser's existing AI recognition logic.
            table_type = parser._identify_table_type(text)

            # Unknown types fall into the catch-all "other" bucket.
            bucket = table_type if table_type in classified_tables else "other"
            classified_tables[bucket].append({
                "index": raw_entry["index"],
                "type": table_type,
                "text_content": text,
                "row_count": raw_entry["row_count"],
                "col_count": raw_entry["col_count"],
            })

        scoring_count = len(classified_tables["scoring"])
        deviation_count = len(classified_tables["deviation"])

        logger.info(f"表格分类完成: 评分表{scoring_count}个, 偏离表{deviation_count}个")

        if scoring_count == 0:
            state["warnings"].append("未识别到评分表,可能影响结果质量")

        state["classified_tables"] = classified_tables
        state["current_step"] = "classify_tables"
        state["progress"] = 0.4

    except Exception as e:
        logger.error(f"表格分类失败: {e}")
        state["error"] = str(e)
        state["should_continue"] = False

    return state
|
||
|
||
|
||
def parse_content_node(state: AnalysisAgentState) -> AnalysisAgentState:
    """Node 4: parse classified tables into scoring and deviation items.

    Scoring tables are parsed into criteria which are split into
    technical vs. commercial lists; ``original_index`` records each
    criterion's global position across all scoring tables.  Deviation
    tables are parsed and accumulated as-is.
    """
    logger.info("开始解析表格内容...")

    try:
        parser = BidParser()
        technical_criteria = []
        commercial_criteria = []
        deviation_items = []

        # Parse every scoring table; global_index preserves the original
        # ordering of criteria across all tables.
        global_index = 0
        for table_info in state["classified_tables"].get("scoring", []):
            parsed = parser._ai_parse_scoring_table(table_info["text_content"])
            for criterion in parsed or []:
                criterion.original_index = global_index
                global_index += 1
                target = (
                    commercial_criteria
                    if criterion.category.value == "commercial"
                    else technical_criteria
                )
                target.append(criterion)

        # Parse every deviation table.
        for table_info in state["classified_tables"].get("deviation", []):
            deviations = parser._ai_parse_deviation_table(table_info["text_content"])
            if deviations:
                deviation_items.extend(deviations)

        logger.info(f"内容解析完成: 技术项{len(technical_criteria)}个, 商务项{len(commercial_criteria)}个, 偏离项{len(deviation_items)}个")

        state["technical_criteria"] = technical_criteria
        state["commercial_criteria"] = commercial_criteria
        state["deviation_items"] = deviation_items
        state["current_step"] = "parse_content"
        state["progress"] = 0.7

    except Exception as e:
        logger.error(f"内容解析失败: {e}")
        state["error"] = str(e)
        state["should_continue"] = False

    return state
|
||
|
||
|
||
def generate_toc_with_agent_node(state: AnalysisAgentState) -> AnalysisAgentState:
    """Node 5: generate the chapter structure (TOC) via TocAgent.

    Requires at least one technical scoring criterion.  Delegates
    generation to TocAgent (new architecture) in "ai" mode, then maps
    each criterion to a generated chapter by matching the category
    suffix encoded in the chapter id.
    """
    logger.info("开始使用TocAgent生成目录...")

    try:
        technical_criteria = state["technical_criteria"]

        # TOC generation is meaningless without technical criteria.
        if not technical_criteria:
            raise ValueError("缺少技术评分项,无法生成章节结构")

        # Reuse the handler from state, or fall back to a default
        # interactive handler when the caller supplied none.
        interaction_handler = state.get("interaction_handler")
        if not interaction_handler:
            # 如果没有提供InteractionHandler,创建默认的交互式处理器
            from .interaction import InteractionHandler, InteractionMode
            interaction_handler = InteractionHandler(mode=InteractionMode.INTERACTIVE)

        # Always AI generation mode here; no template document is used.
        generation_mode = "ai"
        template_file = None

        # Run the TocAgent (new architecture) synchronously.
        from .builders.toc_builder import TocAgent
        toc_agent = TocAgent(interaction_handler)
        result = toc_agent.generate_toc_sync(
            technical_criteria=technical_criteria,
            generation_mode=generation_mode,
            template_file=template_file
        )

        # A non-empty "error" in the result aborts this node.
        if result.get("error"):
            raise ValueError(f"目录生成失败: {result['error']}")

        # Map each criterion to its chapter by category.
        # NOTE(review): assumes chapter ids look like "<a>_<b>_<category...>"
        # (category = everything after the second underscore) — confirm
        # against TocAgent's id scheme.
        chapters = result.get("final_chapters", [])
        for criteria in technical_criteria:
            category = criteria.category.value
            for chapter in chapters:
                if "_" in chapter.id:
                    parts = chapter.id.split("_")
                    if len(parts) >= 3:
                        chapter_category = "_".join(parts[2:])
                        if chapter_category == category:
                            criteria.chapter_id = chapter.id
                            break  # first matching chapter wins

        logger.info(f"目录生成完成: {len(chapters)}个章节")

        # Publish results into the shared state.
        state["preliminary_chapters"] = chapters
        state["technical_criteria"] = technical_criteria
        state["structure_review"] = {}  # TocAgent already performs its own review
        state["current_step"] = "generate_toc"
        state["progress"] = 0.85

        # Propagate any warnings raised by the TocAgent run.
        if result.get("warnings"):
            state["warnings"].extend(result.get("warnings", []))

    except Exception as e:
        logger.error(f"目录生成失败: {e}")
        state["error"] = str(e)
        state["should_continue"] = False

    return state
|
||
|
||
|
||
def finalize_structure_node(state: AnalysisAgentState) -> AnalysisAgentState:
    """Node 6: assemble the final BidStructure.

    Combines the parsed criteria, deviation items and TocAgent chapters
    into a ``BidStructure``, attaches the structure review when present,
    and marks the workflow as finished.
    """
    logger.info("开始最终确定标书结构...")

    try:
        preliminary_chapters = state["preliminary_chapters"]

        # Build the final bid structure from everything gathered so far.
        bid_structure = BidStructure(
            project_name=f"标书项目-{Path(state['source_file']).stem}",
            scoring_criteria=state["technical_criteria"],
            deviation_items=state["deviation_items"],
            chapters=preliminary_chapters,  # chapters produced by TocAgent
            scoring_file=state["source_file"]
        )

        # Attach the AI review result when one was produced upstream.
        structure_review = state.get("structure_review", {})
        if structure_review:
            bid_structure.structure_review = structure_review

        logger.info(f"标书结构确定完成: {len(preliminary_chapters)}个章节")

        state["bid_structure"] = bid_structure
        state["current_step"] = "finalize_structure"
        state["progress"] = 1.0
        state["should_continue"] = False  # terminal node — workflow complete

    except Exception as e:
        logger.error(f"标书结构确定失败: {e}")
        state["error"] = str(e)
        state["should_continue"] = False

    return state
|
||
|
||
|
||
# ========== 条件判断函数 ==========
|
||
|
||
# 直接使用通用的工作流条件判断函数
|
||
|
||
|
||
class AnalysisAgent(BaseAgentFactory):
    """Analysis Agent — Phase-1 analysis agent.

    Drives a strictly linear LangGraph pipeline over the tender
    document: validate file -> extract tables -> classify tables ->
    parse content -> generate TOC -> finalize structure.  Every stage is
    guarded by ``should_continue_workflow`` so any failure short-circuits
    to END.
    """

    # Ordered pipeline stages; the wiring in _build_graph is derived
    # from this list instead of five hand-written conditional edges.
    _PIPELINE = [
        ("validate_file", validate_file_node),
        ("extract_tables", extract_tables_node),
        ("classify_tables", classify_tables_node),
        ("parse_content", parse_content_node),
        ("generate_toc", generate_toc_with_agent_node),
        ("finalize_structure", finalize_structure_node),
    ]

    def __init__(self, interaction_handler=None):
        """Args:
            interaction_handler: optional InteractionHandler passed into
                the workflow state; the TOC node creates a default
                interactive handler when this is None.
        """
        self.settings = get_settings()
        self.interaction_handler = interaction_handler
        self.graph = self._build_graph()

    def _build_graph(self) -> StateGraph:
        """Build and compile the LangGraph workflow.

        The repetitive per-stage ``add_conditional_edges`` calls of the
        original implementation are folded into one loop over
        ``_PIPELINE``; behavior (nodes, entry point, edges) is unchanged.
        """
        workflow = StateGraph(AnalysisAgentState)

        for name, node in self._PIPELINE:
            workflow.add_node(name, node)

        workflow.set_entry_point(self._PIPELINE[0][0])

        # Wire stage i -> stage i+1, aborting to END on failure.
        for (name, _), (successor, _) in zip(self._PIPELINE, self._PIPELINE[1:]):
            workflow.add_conditional_edges(
                name,
                should_continue_workflow,
                {
                    "continue": successor,
                    "end": END
                }
            )

        workflow.add_edge(self._PIPELINE[-1][0], END)

        return workflow.compile()

    @staticmethod
    def _result_from_state(final_state, elapsed: float) -> AnalysisResult:
        """Map the final graph state to an AnalysisResult (success or failure)."""
        if final_state.get("error"):
            return AnalysisResult(
                success=False,
                error_message=final_state["error"],
                warnings=final_state.get("warnings", []),
                execution_time=elapsed
            )

        bid_structure = final_state["bid_structure"]
        return AnalysisResult(
            success=True,
            bid_structure=bid_structure,
            technical_count=len(final_state.get("technical_criteria", [])),
            commercial_count=len(final_state.get("commercial_criteria", [])),
            deviation_count=len(final_state.get("deviation_items", [])),
            chapter_count=len(bid_structure.chapters) if bid_structure else 0,
            warnings=final_state.get("warnings", []),
            execution_time=elapsed
        )

    async def execute(self, source_file: str, progress_callback=None) -> AnalysisResult:
        """Run the Analysis Agent workflow asynchronously.

        Args:
            source_file: path to the tender .docx document.
            progress_callback: NOTE(review): accepted but never invoked
                in the current implementation — kept for backward
                compatibility; confirm with callers before removing.

        Returns:
            AnalysisResult summarising success, item counts, warnings
            and execution time.  Exceptions are captured and reported as
            a failed result rather than raised.
        """
        import time
        start_time = time.time()

        logger.info(f"开始执行Analysis Agent: {source_file}")

        # Fully-populated initial state for the graph run.
        initial_state = AnalysisAgentState(
            source_file=source_file,
            interaction_handler=self.interaction_handler,
            current_step="",
            progress=0.0,
            should_continue=True,
            raw_tables=[],
            classified_tables={},
            technical_criteria=[],
            commercial_criteria=[],
            deviation_items=[],
            preliminary_chapters=[],
            structure_review={},
            bid_structure=None,
            error="",
            warnings=[]
        )

        try:
            final_state = await self.graph.ainvoke(initial_state)
            result = self._result_from_state(final_state, time.time() - start_time)
            logger.info(f"Analysis Agent执行完成,耗时{result.execution_time:.2f}秒")
            return result

        except Exception as e:
            logger.error(f"Analysis Agent执行异常: {e}")
            return AnalysisResult(
                success=False,
                error_message=str(e),
                execution_time=time.time() - start_time
            )

    def execute_sync(self, source_file: str) -> AnalysisResult:
        """Synchronous wrapper for CLI callers."""
        return asyncio.run(self.execute(source_file))
|
||
|
||
|