feat: add context-aware toc automation

commit e028e4fa96 (parent 6f785c9f2c)
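At a glance, the change threads a new `DocumentContext` through the analysis workflow: the tender document is chunked and embedded once, queried for relevant snippets while sub-chapters are generated, and summarized into theme hints for the AI structure review; an `auto_mode` loop then applies high/medium-priority review suggestions for up to `auto_toc_max_rounds` rounds without human input. A minimal sketch of that flow, using only names introduced in this commit (the state plumbing is simplified):

```python
# Sketch of the new context flow; names come from this commit, the glue is simplified.
from bidmaster.utils.document_context import DocumentContextBuilder, DocumentContextSearcher

builder = DocumentContextBuilder()
context = builder.build("tender.docx")                # chunk + embed the source .docx

searcher = DocumentContextSearcher(context, top_k=3)
matches = searcher.search("系统架构先进性;云边协同")   # cosine-similarity retrieval
snippets = [f"{m.section}: {m.text[:200]}" for m in matches]

# snippets fill {context_snippets} in the generate_sub_chapters prompt, and
# searcher.summarize_themes() fills {document_themes} in the review_structure prompt.
```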
@@ -113,11 +113,14 @@ toc_prompts:
 【大类别】: {parent_title}
 【评分项】:
 {criteria_info}
+【招标上下文摘录】:
+{context_snippets}
 
 生成要求:
 1. 为每个评分项生成对应的子标题名称(不要包含编号)
 2. 重要评分项可添加三级子标题(不要包含编号)
-3. 只返回标题文本,编号由Word自动管理
+3. 充分参考上下文摘录中的专业术语和业务背景,体现定制化判断
+4. 只返回标题文本,编号由Word自动管理
 
 返回JSON格式:
 {{
@@ -137,11 +140,14 @@ toc_prompts:
 
 【当前生成的章节结构】:
 {chapters_summary}
+【招标主题线索】:
+{document_themes}
 
 【审查要求】:
 1. 是否缺少重要的标准章节?
 2. 章节顺序是否合理?
-3. 每个评分项是否都有对应章节?
+3. 是否覆盖了主题线索中的关键内容?
+4. 每个评分项是否都有对应章节?
 
 返回JSON格式:
 {{
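Both templates gain placeholders that are filled through `prompt_manager.get_toc_prompt` (see the llm_helper hunks below). They behave like ordinary `str.format` fields; a minimal, self-contained illustration with made-up values:

```python
# Hypothetical values; the placeholder names are the ones added above.
template = "【大类别】: {parent_title}\n【招标上下文摘录】:\n{context_snippets}"
print(template.format(
    parent_title="技术方案",
    context_snippets="项目概述: 系统采用云边协同架构,强调稳定性。",
))
```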
@@ -18,6 +18,7 @@ from ..tools.parser import BidParser, BidStructure, ScoringCriteria, DeviationIt
 from ..config import get_settings
 from ..nodes.toc.base_mixins import WorkflowUtilsMixin
 from .base import BaseAgentFactory
+from ..utils.document_context import DocumentContext, DocumentContextBuilder
 
 logger = logging.getLogger(__name__)
 
@@ -60,9 +61,15 @@ class AnalysisAgentState(TypedDict):
     # Chapter generation pipeline
     preliminary_chapters: List[DocumentChapter]  # preliminary generated chapters
     structure_review: Dict[str, Any]  # AI review result
+    document_context: DocumentContext | None  # tender document context
+    auto_mode: bool  # whether automatic TOC optimization is enabled
+    user_feedback: str  # automatic/manual feedback content
+    pending_suggestions: List[Dict[str, Any]]  # AI suggestions awaiting action
+    auto_optimization_rounds: int  # auto-optimization iterations so far
+    auto_toc_max_rounds: int  # maximum number of auto-optimization rounds
 
     # Final output
-    bid_structure: BidStructure
+    bid_structure: BidStructure | None
 
     # Error handling
     error: str
@@ -154,6 +161,14 @@ def extract_tables_node(state: AnalysisAgentState) -> AnalysisAgentState:
         state["current_step"] = "extract_tables"
         state["progress"] = 0.25
 
+        # Build the full-document context for the later TOC generation
+        try:
+            context_builder = DocumentContextBuilder()
+            document_context = context_builder.build(source_file)
+            state["document_context"] = document_context
+        except Exception as exc:
+            logger.warning(f"构建招标文档上下文失败: {exc}")
+
     except Exception as e:
         logger.error(f"表格提取失败: {e}")
         state["error"] = str(e)
@@ -284,7 +299,9 @@ def generate_toc_with_agent_node(state: AnalysisAgentState) -> AnalysisAgentState:
     result = toc_agent.generate_toc_sync(
         technical_criteria=technical_criteria,
         generation_mode=generation_mode,
-        template_file=template_file
+        template_file=template_file,
+        document_context=state.get("document_context"),
+        auto_mode=state.get("auto_mode")
     )
 
     # Check for errors
@@ -475,6 +492,12 @@ class AnalysisAgent(BaseAgentFactory):
             deviation_items=[],
             preliminary_chapters=[],
             structure_review={},
+            document_context=None,
+            auto_mode=self.settings.auto_toc_mode,
+            user_feedback="",
+            pending_suggestions=[],
+            auto_optimization_rounds=0,
+            auto_toc_max_rounds=self.settings.auto_toc_max_rounds,
             bid_structure=None,
             error="",
             warnings=[]
@@ -482,7 +505,10 @@ class AnalysisAgent(BaseAgentFactory):
 
         try:
             # Run the LangGraph workflow
-            final_state = await self.graph.ainvoke(initial_state)
+            final_state = await self.graph.ainvoke(
+                initial_state,
+                config={"recursion_limit": self.settings.langgraph_recursion_limit}
+            )
 
             # Build the execution result
             if final_state.get("error"):
@@ -8,6 +8,7 @@ from typing import List, Dict, Any, Optional, Callable, Tuple
 from langgraph.graph import StateGraph, END
 
 from ..nodes.base import BaseNode, NodeContext
+from ..config import get_settings
 
 logger = logging.getLogger(__name__)
 
@@ -227,6 +228,8 @@ class BaseAgent:
         """
         self.builder = builder
         self.graph = builder.build()
+        self.settings = get_settings()
+        self.recursion_limit = getattr(self.settings, "langgraph_recursion_limit", 120)
 
     async def execute(self, initial_state: Dict[str, Any]) -> Dict[str, Any]:
         """Execute the agent asynchronously
@@ -240,7 +243,10 @@ class BaseAgent:
         logger.info(f"开始执行Agent,初始状态keys: {list(initial_state.keys())}")
 
         try:
-            final_state = await self.graph.ainvoke(initial_state)
+            final_state = await self.graph.ainvoke(
+                initial_state,
+                config={"recursion_limit": self.recursion_limit}
+            )
             logger.info("Agent执行完成")
             return final_state
         except Exception as e:
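Context for reviewers: LangGraph caps a run at 25 super-steps by default, and the new review → feedback → optimize cycle can exceed that, which is presumably why both agents now pass the limit explicitly. The override uses the standard `RunnableConfig` key, roughly:

```python
# Hedged illustration: raising LangGraph's step budget (the library default is 25).
final_state = await graph.ainvoke(
    initial_state,
    config={"recursion_limit": 120},  # matches the new langgraph_recursion_limit setting
)
```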
@@ -145,7 +145,9 @@ class TocAgent(BaseAgent, BaseAgentFactory):
     async def generate_toc(self,
                            technical_criteria: list,
                            generation_mode: Optional[str] = None,
-                           template_file: Optional[str] = None) -> Dict[str, Any]:
+                           template_file: Optional[str] = None,
+                           document_context: Optional[Any] = None,
+                           auto_mode: Optional[bool] = None) -> Dict[str, Any]:
         """Generate the TOC structure
 
         Args:
@@ -165,7 +167,10 @@ class TocAgent(BaseAgent, BaseAgentFactory):
             "structure_review": {},
             "final_chapters": [],
             "error": "",
-            "warnings": []
+            "warnings": [],
+            "document_context": document_context,
+            "auto_mode": True if auto_mode is None else auto_mode,
+            "user_feedback": ""
         }
 
         # If preset parameters were provided, add them to the state
@@ -182,7 +187,9 @@ class TocAgent(BaseAgent, BaseAgentFactory):
     def generate_toc_sync(self,
                           technical_criteria: list,
                           generation_mode: Optional[str] = None,
-                          template_file: Optional[str] = None) -> Dict[str, Any]:
+                          template_file: Optional[str] = None,
+                          document_context: Optional[Any] = None,
+                          auto_mode: Optional[bool] = None) -> Dict[str, Any]:
         """Synchronously generate the TOC structure
 
         Args:
@@ -195,5 +202,5 @@ class TocAgent(BaseAgent, BaseAgentFactory):
         """
         import asyncio
         return asyncio.run(
-            self.generate_toc(technical_criteria, generation_mode, template_file)
+            self.generate_toc(technical_criteria, generation_mode, template_file, document_context, auto_mode)
         )
@@ -52,6 +52,11 @@ class Settings(BaseSettings):
     embedding_model: str = Field(default="text-embedding-3-small", description="嵌入模型")
     chunk_size: int = Field(default=1000, description="文档块大小")
     chunk_overlap: int = Field(default=200, description="块重叠大小")
+    document_context_chunk_size: int = Field(default=800, description="目录生成上下文分块大小")
+    document_context_overlap: int = Field(default=150, description="目录生成上下文块重叠")
+    document_context_top_k: int = Field(default=3, description="上下文检索结果数量")
+    langgraph_recursion_limit: int = Field(default=120, description="LangGraph执行的递归上限")
+    auto_toc_max_rounds: int = Field(default=2, description="自动目录优化最大迭代次数")
 
     # Performance settings
     max_workers: int = Field(default=4, description="最大工作线程数")
@@ -59,6 +64,7 @@ class Settings(BaseSettings):
 
     # Interaction settings
     interaction_timeout: int = Field(default=60, description="用户交互超时时间(秒)")
+    auto_toc_mode: bool = Field(default=True, description="目录生成是否默认自动处理反馈")
 
     # Logging settings
     log_level: str = Field(default="INFO", description="日志级别")
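A quick sketch of how the new settings reach the workflow (field names from this diff; the wiring summary paraphrases the hunks above and below):

```python
# Hedged sketch: the new knobs as consumed elsewhere in this commit.
from bidmaster.config import get_settings

settings = get_settings()
state_fragment = {
    "auto_mode": settings.auto_toc_mode,                  # default True: feedback handled automatically
    "auto_toc_max_rounds": settings.auto_toc_max_rounds,  # cap on auto-optimization rounds (default 2)
}
# document_context_* drive chunking/retrieval in DocumentContextBuilder/Searcher;
# langgraph_recursion_limit is forwarded to graph.ainvoke(..., config=...).
```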
@@ -4,7 +4,7 @@
 """
 
 import logging
-from typing import Dict, List, Any
+from typing import Dict, List, Any, Optional
 
 from ..base import BaseNode, NodeContext
 from ...tools.parser import ScoringCriteria, DocumentChapter
@@ -12,6 +12,7 @@ from .category_manager import CategoryManager
 from .factories import ChapterFactory
 from .llm_helper import LLMHelper
 from .base_mixins import TocNodeBase
+from ...utils.document_context import DocumentContextSearcher
 
 logger = logging.getLogger(__name__)
 
@@ -40,11 +41,23 @@ class GenerateSubChaptersNode(BaseNode, TocNodeBase):
 
         preliminary_chapters = state["preliminary_chapters"]
         technical_criteria = state["technical_criteria"]
+        document_context = state.get("document_context")
+
+        context_searcher: Optional[DocumentContextSearcher] = None
+        if document_context:
+            try:
+                context_searcher = DocumentContextSearcher(document_context)
+            except Exception as exc:
+                self.log_step_info("context_searcher", f"上下文检索器初始化失败: {exc}")
 
         # Generate sub-headings for each chapter
         enhanced_chapters = []
         for chapter in preliminary_chapters:
-            enhanced_chapter = self._enhance_chapter_with_subs(chapter, technical_criteria)
+            enhanced_chapter = self._enhance_chapter_with_subs(
+                chapter,
+                technical_criteria,
+                context_searcher
+            )
             enhanced_chapters.append(enhanced_chapter)
 
         self.log_step_info("generate_sub_chapters", f"完成{len(enhanced_chapters)}个章节的子标题生成")
@@ -53,7 +66,8 @@ class GenerateSubChaptersNode(BaseNode, TocNodeBase):
 
     def _enhance_chapter_with_subs(self,
                                    chapter: DocumentChapter,
-                                   technical_criteria: List[ScoringCriteria]) -> DocumentChapter:
+                                   technical_criteria: List[ScoringCriteria],
+                                   context_searcher: Optional[DocumentContextSearcher] = None) -> DocumentChapter:
         """Enhance a chapter with sub-headings
 
         Args:
@@ -71,8 +85,14 @@ class GenerateSubChaptersNode(BaseNode, TocNodeBase):
             self.log_step_info("enhance_chapter", warning_msg)
             return chapter
 
+        context_snippets = self._collect_context_snippets(chapter, corresponding_criteria, context_searcher)
+
         # Generate sub-headings with AI
-        sub_chapters_data = LLMHelper.generate_sub_chapters_ai(corresponding_criteria, chapter)
+        sub_chapters_data = LLMHelper.generate_sub_chapters_ai(
+            corresponding_criteria,
+            chapter,
+            context_snippets=context_snippets
+        )
 
         if sub_chapters_data:
             chapter.children = ChapterFactory.create_chapters_from_ai_response(chapter, sub_chapters_data)
@@ -82,4 +102,34 @@ class GenerateSubChaptersNode(BaseNode, TocNodeBase):
             self.log_step_info("enhance_chapter", f"章节 {chapter.title} AI生成子标题失败")
             raise RuntimeError(error_msg)
 
-        return chapter
+        return chapter
+
+    def _collect_context_snippets(self,
+                                  chapter: DocumentChapter,
+                                  criteria: List[ScoringCriteria],
+                                  context_searcher: Optional[DocumentContextSearcher]) -> Optional[List[str]]:
+        if not context_searcher or not criteria:
+            return None
+
+        query_parts = [chapter.title]
+        for item in criteria:
+            query_parts.append(item.item_name)
+            if item.description:
+                query_parts.append(item.description)
+
+        query = ";".join(part for part in query_parts if part)
+        if not query:
+            return None
+
+        matches = context_searcher.search(query)
+        if not matches:
+            return None
+
+        snippets: List[str] = []
+        for match in matches:
+            snippet_text = match.text.strip().replace("\n", " ")
+            if not snippet_text:
+                continue
+            snippets.append(f"{match.section}: {snippet_text[:200]}")
+
+        return snippets or None
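Each retrieved match is flattened to a `"{section}: {first 200 chars}"` string before being handed to the prompt; for example, the match used in the unit tests below renders as:

```python
# Format produced per match by _collect_context_snippets (illustrative values).
section, text = "项目概述", "系统采用云边协同架构,强调稳定性。"
snippet_text = text.strip().replace("\n", " ")
print(f"{section}: {snippet_text[:200]}")
# 项目概述: 系统采用云边协同架构,强调稳定性。
```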
@@ -80,12 +80,14 @@ class LLMHelper:
 
     @staticmethod
     def generate_sub_chapters_ai(criteria_list: List[ScoringCriteria],
-                                 parent_chapter: DocumentChapter) -> Optional[List[Dict[str, Any]]]:
+                                 parent_chapter: DocumentChapter,
+                                 context_snippets: Optional[List[str]] = None) -> Optional[List[Dict[str, Any]]]:
         """Generate sub-chapter data with AI
 
         Args:
             criteria_list: the scoring criteria mapped to this chapter
             parent_chapter: the parent chapter
+            context_snippets: relevant excerpts from the tender document
 
         Returns:
             List of sub-chapter dicts, or None on failure
@@ -97,10 +99,13 @@ class LLMHelper:
 
         # Fetch the prompt template from configuration
         prompt_manager = get_prompt_manager()
+        context_text = "\n".join(context_snippets) if context_snippets else "(暂无上下文摘录)"
+
         prompt = prompt_manager.get_toc_prompt(
             "generate_sub_chapters",
             parent_title=parent_chapter.title,
-            criteria_info=chr(10).join(criteria_info)
+            criteria_info=chr(10).join(criteria_info),
+            context_snippets=context_text
         )
 
         try:
@@ -117,12 +122,14 @@ class LLMHelper:
 
     @staticmethod
     def review_structure_ai(technical_criteria: List[ScoringCriteria],
-                            preliminary_chapters: List[DocumentChapter]) -> Dict[str, Any]:
+                            preliminary_chapters: List[DocumentChapter],
+                            document_themes: Optional[str] = None) -> Dict[str, Any]:
         """Review the TOC structure with AI
 
         Args:
             technical_criteria: technical scoring criteria
             preliminary_chapters: the preliminary generated chapters
+            document_themes: theme hints from the tender document
 
         Returns:
             Review result dict
@@ -134,10 +141,13 @@ class LLMHelper:
 
         # Fetch the prompt template from configuration
         prompt_manager = get_prompt_manager()
+        themes_text = document_themes or "(暂无主题线索)"
+
         review_prompt = prompt_manager.get_toc_prompt(
             "review_structure",
             criteria_summary=criteria_summary,
-            chapters_summary=chapters_summary
+            chapters_summary=chapters_summary,
+            document_themes=themes_text
         )
 
         try:
@@ -58,18 +58,21 @@ class OptimizeWithFeedbackNode(BaseNode, TocNodeBase):
                 return self._update_state(state,
                                           adjusted_chapters=optimized_chapters,
                                           user_feedback="",  # clear the feedback
-                                          needs_optimization=False)
+                                          needs_optimization=False,
+                                          pending_suggestions=[])
             else:
                 logger.warning("AI优化结果未按用户要求修改,保持原有结构")
                 return self._update_state(state,
                                           user_feedback="",
-                                          needs_optimization=False)
+                                          needs_optimization=False,
+                                          pending_suggestions=[])
         else:
            # Optimization failed; keep the existing structure
            logger.warning("AI优化失败,保持原有结构")
            return self._update_state(state,
                                      user_feedback="",
-                                     needs_optimization=False)
+                                     needs_optimization=False,
+                                     pending_suggestions=[])
 
    def _optimize_with_ai(self, chapters: List[DocumentChapter], feedback: str) -> Optional[List[DocumentChapter]]:
        """Optimize the TOC structure with AI
 
@@ -10,6 +10,7 @@ from ..base import BaseNode, NodeContext
 from ...tools.parser import ScoringCriteria, DocumentChapter
 from .llm_helper import LLMHelper
 from .base_mixins import TocNodeBase
+from ...utils.document_context import DocumentContextSearcher
 
 logger = logging.getLogger(__name__)
 
@@ -39,8 +40,21 @@ class ReviewStructureNode(BaseNode, TocNodeBase):
         technical_criteria = state["technical_criteria"]
         preliminary_chapters = state["preliminary_chapters"]
 
+        document_context = state.get("document_context")
+        document_themes = None
+        if document_context:
+            try:
+                theme_searcher = DocumentContextSearcher(document_context)
+                document_themes = theme_searcher.summarize_themes()
+            except Exception as exc:
+                self.log_step_info("review_structure", f"主题线索生成失败: {exc}")
+
         # Run the AI review
-        review_result = LLMHelper.review_structure_ai(technical_criteria, preliminary_chapters)
+        review_result = LLMHelper.review_structure_ai(
+            technical_criteria,
+            preliminary_chapters,
+            document_themes=document_themes
+        )
 
         # Add warnings based on the review result
         if review_result.get("suggestions"):
@@ -51,4 +65,8 @@ class ReviewStructureNode(BaseNode, TocNodeBase):
         optimization_score = review_result.get('optimization_score', 'N/A')
         self.log_step_info("review_structure", f"AI审查完成,优化评分: {optimization_score}")
 
-        return self._update_state(state, structure_review=review_result)
+        return self._update_state(
+            state,
+            structure_review=review_result,
+            pending_suggestions=review_result.get("suggestions", [])
+        )
@@ -37,6 +37,10 @@ class UserFeedbackNode(BaseNode, TocNodeBase):
             raise ValueError("缺少最终章节数据")
 
         final_chapters = state.get("final_chapters", [])
+        auto_mode = state.get("auto_mode", False)
+
+        if auto_mode:
+            return self._handle_auto_feedback(state, final_chapters)
 
         # Get the interaction handler
         interaction_handler = state.get("interaction_handler")
@@ -99,6 +103,53 @@ class UserFeedbackNode(BaseNode, TocNodeBase):
                                   user_feedback=feedback,
                                   needs_optimization=True)
 
+    def _handle_auto_feedback(self,
+                              state: Dict[str, Any],
+                              final_chapters: List[DocumentChapter]) -> Dict[str, Any]:
+        review = state.get("structure_review", {}) or {}
+        suggestions = state.get("pending_suggestions") or review.get("suggestions", []) or []
+
+        max_rounds = state.get("auto_toc_max_rounds", 1)
+        rounds = state.get("auto_optimization_rounds", 0)
+
+        if rounds >= max_rounds:
+            logger.info("自动模式: 达到最大优化次数,自动结束")
+            return self._update_state(state,
+                                      should_continue=False,
+                                      user_feedback="",
+                                      needs_optimization=False,
+                                      pending_suggestions=[])
+
+        prioritized = [
+            item for item in suggestions
+            if item.get("priority", "low").lower() in {"high", "medium"}
+        ]
+
+        if not prioritized:
+            logger.info("自动模式: 无需额外优化,直接结束")
+            return self._update_state(state,
+                                      should_continue=False,
+                                      user_feedback="",
+                                      needs_optimization=False,
+                                      pending_suggestions=[])
+
+        feedback_lines = []
+        for idx, item in enumerate(prioritized[:5], 1):
+            description = item.get("description", "补充完善章节内容")
+            action_type = item.get("type", "adjust")
+            priority = item.get("priority", "medium")
+            feedback_lines.append(f"{idx}. [{priority}] {action_type}: {description}")
+
+        feedback_text = "\n".join(feedback_lines)
+        logger.info("自动模式: 根据审查建议触发目录优化")
+
+        return self._update_state(state,
+                                  should_continue=True,
+                                  user_feedback=feedback_text,
+                                  needs_optimization=True,
+                                  pending_suggestions=[],
+                                  auto_optimization_rounds=rounds + 1)
+
     def _format_chapters_for_display(self, chapters: List[DocumentChapter], indent_level: int = 0) -> str:
         """Format chapters for display
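To make the auto-feedback concrete: `_handle_auto_feedback` keeps only high/medium-priority suggestions and joins them with the f-string above. A small illustration with hypothetical review output:

```python
# Hypothetical suggestions and the feedback text the node would synthesize.
suggestions = [
    {"description": "补充智慧运维章节", "priority": "high", "type": "add"},
    {"description": "调整章节顺序", "priority": "medium", "type": "modify"},
    {"description": "微调措辞", "priority": "low", "type": "modify"},  # filtered out
]
prioritized = [s for s in suggestions if s["priority"] in {"high", "medium"}]
print("\n".join(
    f"{idx}. [{item['priority']}] {item['type']}: {item['description']}"
    for idx, item in enumerate(prioritized, 1)
))
# 1. [high] add: 补充智慧运维章节
# 2. [medium] modify: 调整章节顺序
```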
@@ -1,5 +1,15 @@
 """Utility modules"""
 
 from .timeout_input import timeout_prompt
+from .document_context import (
+    DocumentContext,
+    DocumentContextBuilder,
+    DocumentContextSearcher,
+)
 
-__all__ = ["timeout_prompt"]
+__all__ = [
+    "timeout_prompt",
+    "DocumentContext",
+    "DocumentContextBuilder",
+    "DocumentContextSearcher",
+]
src/bidmaster/utils/document_context.py (new file, +284)
@@ -0,0 +1,284 @@
"""招标文档上下文构建与检索工具。
|
||||
|
||||
提供基于招标文件原文的分块、嵌入和相似度搜索能力,
|
||||
用于目录生成阶段的智能参考。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from docx import Document
|
||||
from docx.oxml.table import CT_Tbl
|
||||
from docx.oxml.text.paragraph import CT_P
|
||||
from docx.table import Table
|
||||
from docx.text.paragraph import Paragraph
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
from ..config import get_settings
|
||||
from ..tools.parser import BidParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentContextChunk:
|
||||
"""单个上下文块的数据结构。"""
|
||||
|
||||
index: int
|
||||
text: str
|
||||
section: str
|
||||
source_type: str
|
||||
metadata: Dict[str, Any]
|
||||
embedding: List[float]
|
||||
embedding_norm: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentContext:
|
||||
"""完整招标文档的上下文集合。"""
|
||||
|
||||
source_file: str
|
||||
embedding_model: str
|
||||
chunks: List[DocumentContextChunk] = field(default_factory=list)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return len(self.chunks) == 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentContextMatch:
|
||||
"""检索结果结构。"""
|
||||
|
||||
text: str
|
||||
section: str
|
||||
score: float
|
||||
source_type: str
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class DocumentContextBuilder:
|
||||
"""构建招标文档上下文。"""
|
||||
|
||||
def __init__(self,
|
||||
settings=None,
|
||||
chunk_size: Optional[int] = None,
|
||||
chunk_overlap: Optional[int] = None,
|
||||
embedding_fn: Optional[Callable[[Iterable[str]], List[List[float]]]] = None):
|
||||
self.settings = settings or get_settings()
|
||||
self.chunk_size = chunk_size or getattr(self.settings, "document_context_chunk_size", 800)
|
||||
self.chunk_overlap = chunk_overlap or getattr(self.settings, "document_context_overlap", 200)
|
||||
self.embedding_function = embedding_fn or self._create_embedding_function()
|
||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
length_function=len,
|
||||
)
|
||||
self.parser = BidParser()
|
||||
|
||||
def build(self, file_path: str) -> DocumentContext:
|
||||
"""构建指定Word文件的上下文。"""
|
||||
try:
|
||||
document = Document(file_path)
|
||||
except Exception as exc:
|
||||
logger.error(f"加载Word文档失败: {exc}")
|
||||
return DocumentContext(source_file=file_path, embedding_model=self.settings.embedding_model)
|
||||
|
||||
segments = self._collect_segments(document)
|
||||
if not segments:
|
||||
logger.warning("未从招标文件提取到可用文本片段")
|
||||
return DocumentContext(source_file=file_path, embedding_model=self.settings.embedding_model)
|
||||
|
||||
raw_chunks: List[Dict[str, Any]] = []
|
||||
for segment in segments:
|
||||
pieces = self.text_splitter.split_text(segment["text"])
|
||||
for piece in pieces:
|
||||
cleaned_piece = piece.strip()
|
||||
if not cleaned_piece:
|
||||
continue
|
||||
raw_chunks.append({
|
||||
"text": cleaned_piece,
|
||||
"section": segment["section"],
|
||||
"source_type": segment["type"],
|
||||
"metadata": {
|
||||
"source_type": segment["type"],
|
||||
"source_section": segment["section"],
|
||||
}
|
||||
})
|
||||
|
||||
if not raw_chunks:
|
||||
logger.warning("文本分块后为空,跳过上下文构建")
|
||||
return DocumentContext(source_file=file_path, embedding_model=self.settings.embedding_model)
|
||||
|
||||
embeddings = self._embed_texts([chunk["text"] for chunk in raw_chunks])
|
||||
chunks: List[DocumentContextChunk] = []
|
||||
for index, (chunk_info, vector) in enumerate(zip(raw_chunks, embeddings)):
|
||||
vector_list = [float(v) for v in vector]
|
||||
norm = float(np.linalg.norm(vector_list)) or 1e-8
|
||||
chunks.append(DocumentContextChunk(
|
||||
index=index,
|
||||
text=chunk_info["text"],
|
||||
section=chunk_info["section"],
|
||||
source_type=chunk_info["source_type"],
|
||||
metadata=chunk_info["metadata"],
|
||||
embedding=vector_list,
|
||||
embedding_norm=norm
|
||||
))
|
||||
|
||||
logger.info(f"上下文构建完成: {len(chunks)} 个片段")
|
||||
return DocumentContext(
|
||||
source_file=file_path,
|
||||
embedding_model=self.settings.embedding_model,
|
||||
chunks=chunks
|
||||
)
|
||||
|
||||
def _collect_segments(self, document: Document) -> List[Dict[str, Any]]:
|
||||
segments: List[Dict[str, Any]] = []
|
||||
current_section = "全文"
|
||||
|
||||
for block in self._iter_block_items(document):
|
||||
if isinstance(block, Paragraph):
|
||||
text = block.text.strip()
|
||||
if not text:
|
||||
continue
|
||||
|
||||
if self._is_heading(block):
|
||||
current_section = text
|
||||
segments.append({"text": text, "section": current_section, "type": "heading"})
|
||||
else:
|
||||
segments.append({"text": text, "section": current_section, "type": "paragraph"})
|
||||
|
||||
elif isinstance(block, Table):
|
||||
table_text = self.parser.extract_table_text(block)
|
||||
if table_text:
|
||||
segments.append({"text": table_text, "section": current_section, "type": "table"})
|
||||
|
||||
return segments
|
||||
|
||||
def _iter_block_items(self, document: Document):
|
||||
body = document.element.body
|
||||
for child in body.iterchildren():
|
||||
if isinstance(child, CT_P):
|
||||
yield Paragraph(child, document)
|
||||
elif isinstance(child, CT_Tbl):
|
||||
yield Table(child, document)
|
||||
|
||||
def _is_heading(self, paragraph: Paragraph) -> bool:
|
||||
style_name = getattr(paragraph.style, "name", "") or ""
|
||||
return style_name.startswith("Heading") or "标题" in style_name
|
||||
|
||||
def _embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
try:
|
||||
vectors = self.embedding_function(texts)
|
||||
return [list(vec) for vec in vectors]
|
||||
except Exception as exc:
|
||||
logger.error(f"计算上下文嵌入失败: {exc}")
|
||||
# 退化为简单的稀疏向量,避免完全失败
|
||||
return [[float(len(text))] for text in texts]
|
||||
|
||||
def _create_embedding_function(self):
|
||||
model_name = self.settings.embedding_model
|
||||
if model_name.startswith("text-embedding-"):
|
||||
return embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key=self.settings.api_key,
|
||||
model_name=model_name
|
||||
)
|
||||
return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)
|
||||
|
||||
|
||||
class DocumentContextSearcher:
|
||||
"""基于向量相似度的上下文检索器。"""
|
||||
|
||||
def __init__(self,
|
||||
context: Optional[DocumentContext],
|
||||
settings=None,
|
||||
embedding_fn: Optional[Callable[[Iterable[str]], List[List[float]]]] = None,
|
||||
top_k: Optional[int] = None):
|
||||
self.context = context or DocumentContext("", "", [])
|
||||
self.settings = settings or get_settings()
|
||||
self.top_k = top_k or getattr(self.settings, "document_context_top_k", 3)
|
||||
self.embedding_function = embedding_fn or self._create_embedding_function()
|
||||
|
||||
self._chunk_matrix = None
|
||||
self._normalized_matrix = None
|
||||
self._prepare_matrix()
|
||||
|
||||
def _prepare_matrix(self) -> None:
|
||||
if self.context.is_empty():
|
||||
return
|
||||
|
||||
embeddings = np.array([chunk.embedding for chunk in self.context.chunks], dtype=np.float32)
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1e-8
|
||||
self._chunk_matrix = embeddings
|
||||
self._normalized_matrix = embeddings / norms
|
||||
|
||||
def search(self, query: str, top_k: Optional[int] = None) -> List[DocumentContextMatch]:
|
||||
if not query or self._normalized_matrix is None:
|
||||
return []
|
||||
|
||||
k = top_k or self.top_k
|
||||
query_vector = self._embed_query(query)
|
||||
if query_vector is None:
|
||||
return []
|
||||
|
||||
query_norm = np.linalg.norm(query_vector)
|
||||
if query_norm == 0:
|
||||
return []
|
||||
normalized_query = query_vector / query_norm
|
||||
similarities = self._normalized_matrix @ normalized_query
|
||||
|
||||
top_indices = similarities.argsort()[-k:][::-1]
|
||||
matches: List[DocumentContextMatch] = []
|
||||
for idx in top_indices:
|
||||
score = float(similarities[idx])
|
||||
if score <= 0:
|
||||
continue
|
||||
chunk = self.context.chunks[int(idx)]
|
||||
matches.append(DocumentContextMatch(
|
||||
text=chunk.text,
|
||||
section=chunk.section,
|
||||
score=score,
|
||||
source_type=chunk.source_type,
|
||||
metadata=chunk.metadata
|
||||
))
|
||||
|
||||
return matches
|
||||
|
||||
def summarize_themes(self, max_items: int = 5) -> str:
|
||||
if self.context.is_empty():
|
||||
return "(暂无主题线索)"
|
||||
|
||||
section_map: Dict[str, List[str]] = {}
|
||||
for chunk in self.context.chunks:
|
||||
section = chunk.section or "未命名章节"
|
||||
section_map.setdefault(section, []).append(chunk.text.strip())
|
||||
|
||||
sorted_sections = sorted(section_map.items(), key=lambda item: len(item[1]), reverse=True)
|
||||
lines = []
|
||||
for section, excerpts in sorted_sections[:max_items]:
|
||||
preview = excerpts[0].replace("\n", " ")[:120]
|
||||
lines.append(f"- {section}: {preview}")
|
||||
|
||||
return "\n".join(lines) if lines else "(暂无主题线索)"
|
||||
|
||||
def _embed_query(self, query: str) -> Optional[np.ndarray]:
|
||||
try:
|
||||
vector = self.embedding_function([query])[0]
|
||||
return np.array(vector, dtype=np.float32)
|
||||
except Exception as exc:
|
||||
logger.error(f"上下文检索计算查询嵌入失败: {exc}")
|
||||
return None
|
||||
|
||||
def _create_embedding_function(self):
|
||||
model_name = getattr(self.context, "embedding_model", None) or self.settings.embedding_model
|
||||
if model_name.startswith("text-embedding-"):
|
||||
return embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key=self.settings.api_key,
|
||||
model_name=model_name
|
||||
)
|
||||
return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)
|
||||
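The searcher's score is plain cosine similarity: chunk vectors are normalized once in `_prepare_matrix`, and a query is scored by a dot product against its own normalized vector. A tiny self-contained check of that equivalence:

```python
# Cosine similarity exactly as DocumentContextSearcher computes it (numpy only).
import numpy as np

chunks = np.array([[1.0, 0.0], [1.0, 1.0]], dtype=np.float32)
norms = np.linalg.norm(chunks, axis=1, keepdims=True)
normalized = chunks / norms                      # same as _prepare_matrix

query = np.array([2.0, 2.0], dtype=np.float32)
scores = normalized @ (query / np.linalg.norm(query))
print(scores.round(4))                           # [0.7071 1.] -> [1, 1] is the best match
```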
tests/unit/test_document_context.py (new file, +70)
@@ -0,0 +1,70 @@
+from pathlib import Path
+
+from docx import Document
+
+from bidmaster.utils.document_context import (
+    DocumentContextBuilder,
+    DocumentContextSearcher,
+)
+
+
+def _create_sample_doc(doc_path: Path) -> None:
+    doc = Document()
+    doc.add_heading("第一章 项目总体概述", level=1)
+    doc.add_paragraph("本项目聚焦城市智慧照明系统建设,强调云边协同与多维感知能力。")
+    doc.add_heading("第二章 建设目标", level=1)
+    doc.add_paragraph("目标包括统一管控平台、智能终端、数据中台三大部分。")
+
+    table = doc.add_table(rows=2, cols=2)
+    table.rows[0].cells[0].text = "指标"
+    table.rows[0].cells[1].text = "要求"
+    table.rows[1].cells[0].text = "系统稳定性"
+    table.rows[1].cells[1].text = "7x24小时无故障运行"
+
+    doc.save(doc_path)
+
+
+def _dummy_embedding(texts):
+    return [[float(len(text))] for text in texts]
+
+
+def test_document_context_builder_creates_chunks(tmp_path):
+    doc_path = tmp_path / "context.docx"
+    _create_sample_doc(doc_path)
+
+    builder = DocumentContextBuilder(
+        chunk_size=120,
+        chunk_overlap=10,
+        embedding_fn=_dummy_embedding,
+    )
+
+    context = builder.build(str(doc_path))
+
+    assert not context.is_empty()
+    assert all(chunk.embedding for chunk in context.chunks)
+    assert any("项目总体概述" in chunk.section for chunk in context.chunks)
+
+
+def test_document_context_searcher_returns_matches(tmp_path):
+    doc_path = tmp_path / "search.docx"
+    _create_sample_doc(doc_path)
+
+    builder = DocumentContextBuilder(
+        chunk_size=80,
+        chunk_overlap=10,
+        embedding_fn=_dummy_embedding,
+    )
+    context = builder.build(str(doc_path))
+
+    searcher = DocumentContextSearcher(
+        context,
+        embedding_fn=_dummy_embedding,
+        top_k=2,
+    )
+
+    matches = searcher.search("智慧照明平台")
+    assert matches
+    assert matches[0].score > 0
+
+    themes = searcher.summarize_themes()
+    assert "第一章" in themes
tests/unit/test_toc_generation.py (new file, +140)
@@ -0,0 +1,140 @@
+from bidmaster.nodes.base import NodeContext
+from bidmaster.nodes.toc import generate_sub_chapters as gen_module
+from bidmaster.nodes.toc.generate_sub_chapters import GenerateSubChaptersNode
+from bidmaster.nodes.toc.user_feedback import UserFeedbackNode
+from bidmaster.tools.parser import DocumentChapter, ScoringCriteria, TechnicalCategory
+from bidmaster.utils.document_context import DocumentContext, DocumentContextMatch
+
+
+def test_generate_sub_chapters_uses_context(monkeypatch):
+    captured = {}
+
+    def fake_generate(criteria_list, parent_chapter, context_snippets=None):
+        captured["context"] = context_snippets
+        return [{"title": "示例小节", "level": 2, "score": 5, "children": []}]
+
+    class DummySearcher:
+        def __init__(self, *args, **kwargs):
+            pass
+
+        def search(self, query, top_k=None):
+            return [
+                DocumentContextMatch(
+                    text="系统采用云边协同架构,强调稳定性。",
+                    section="项目概述",
+                    score=0.95,
+                    source_type="paragraph",
+                    metadata={},
+                )
+            ]
+
+    monkeypatch.setattr(gen_module, "DocumentContextSearcher", DummySearcher)
+    monkeypatch.setattr(gen_module.LLMHelper, "generate_sub_chapters_ai", staticmethod(fake_generate))
+
+    chapter = DocumentChapter(
+        id="chapter_01_technical_solution",
+        title="技术方案",
+        level=1,
+        score=30,
+    )
+    criteria = [
+        ScoringCriteria(
+            item_name="系统架构先进性",
+            max_score=10,
+            description="阐述云边端协同",
+            category=TechnicalCategory.TECHNICAL_SOLUTION,
+            chapter_id="chapter_01_technical_solution",
+            original_index=0,
+        )
+    ]
+
+    state = {
+        "preliminary_chapters": [chapter],
+        "technical_criteria": criteria,
+        "document_context": DocumentContext("demo.docx", "test-model", []),
+    }
+
+    node = GenerateSubChaptersNode()
+    result = node.execute(state, NodeContext())
+
+    assert result["preliminary_chapters"][0].children
+    assert captured["context"] and "项目概述" in captured["context"][0]
+
+
+def test_user_feedback_auto_triggers_optimization():
+    chapters = [
+        DocumentChapter(
+            id="chapter_01",
+            title="技术方案",
+            level=1,
+            score=30,
+        )
+    ]
+
+    state = {
+        "final_chapters": chapters,
+        "structure_review": {
+            "suggestions": [
+                {"description": "补充智慧运维章节", "priority": "high", "type": "add"},
+                {"description": "可选优化", "priority": "low", "type": "modify"},
+            ]
+        },
+        "auto_mode": True,
+        "auto_toc_max_rounds": 2,
+        "auto_optimization_rounds": 0,
+    }
+
+    node = UserFeedbackNode()
+    result = node.execute(state, NodeContext())
+
+    assert result["needs_optimization"] is True
+    assert "补充" in result["user_feedback"]
+    assert result["auto_optimization_rounds"] == 1
+    assert result["pending_suggestions"] == []
+
+
+def test_user_feedback_auto_skips_when_no_priority():
+    chapters = [
+        DocumentChapter(
+            id="chapter_02",
+            title="实施计划",
+            level=1,
+            score=20,
+        )
+    ]
+
+    state = {
+        "final_chapters": chapters,
+        "structure_review": {"suggestions": [{"description": "微调描述", "priority": "low", "type": "modify"}]},
+        "auto_mode": True,
+        "auto_toc_max_rounds": 1,
+    }
+
+    node = UserFeedbackNode()
+    result = node.execute(state, NodeContext())
+
+    assert result["needs_optimization"] is False
+    assert result["user_feedback"] == ""
+
+
+def test_user_feedback_auto_respects_round_limit():
+    chapters = [
+        DocumentChapter(id="chapter_03", title="服务方案", level=1, score=10)
+    ]
+
+    state = {
+        "final_chapters": chapters,
+        "structure_review": {
+            "suggestions": [{"description": "增加服务考核", "priority": "high", "type": "add"}]
+        },
+        "auto_mode": True,
+        "auto_toc_max_rounds": 1,
+        "auto_optimization_rounds": 1,
+        "pending_suggestions": [{"description": "增加服务考核", "priority": "high", "type": "add"}],
+    }
+
+    node = UserFeedbackNode()
+    result = node.execute(state, NodeContext())
+
+    assert result["needs_optimization"] is False
+    assert result["pending_suggestions"] == []