diff --git a/config/prompts.yaml b/config/prompts.yaml index f19d48e..2debeb8 100644 --- a/config/prompts.yaml +++ b/config/prompts.yaml @@ -113,11 +113,14 @@ toc_prompts: 【大类别】: {parent_title} 【评分项】: {criteria_info} + 【招标上下文摘录】: + {context_snippets} 生成要求: 1. 为每个评分项生成对应的子标题名称(不要包含编号) 2. 重要评分项可添加三级子标题(不要包含编号) - 3. 只返回标题文本,编号由Word自动管理 + 3. 充分参考上下文摘录中的专业术语和业务背景,体现定制化判断 + 4. 只返回标题文本,编号由Word自动管理 返回JSON格式: {{ @@ -137,11 +140,14 @@ toc_prompts: 【当前生成的章节结构】: {chapters_summary} + 【招标主题线索】: + {document_themes} 【审查要求】: 1. 是否缺少重要的标准章节? 2. 章节顺序是否合理? - 3. 每个评分项是否都有对应章节? + 3. 是否覆盖了主题线索中的关键内容? + 4. 每个评分项是否都有对应章节? 返回JSON格式: {{ diff --git a/src/bidmaster/agents/analysis.py b/src/bidmaster/agents/analysis.py index 8f86012..e178c77 100644 --- a/src/bidmaster/agents/analysis.py +++ b/src/bidmaster/agents/analysis.py @@ -18,6 +18,7 @@ from ..tools.parser import BidParser, BidStructure, ScoringCriteria, DeviationIt from ..config import get_settings from ..nodes.toc.base_mixins import WorkflowUtilsMixin from .base import BaseAgentFactory +from ..utils.document_context import DocumentContext, DocumentContextBuilder logger = logging.getLogger(__name__) @@ -60,9 +61,15 @@ class AnalysisAgentState(TypedDict): # 章节生成过程 preliminary_chapters: List[DocumentChapter] # 初步生成的章节 structure_review: Dict[str, Any] # AI审查结果 + document_context: DocumentContext | None # 招标文档上下文 + auto_mode: bool # 是否启用自动目录优化 + user_feedback: str # 自动/人工反馈内容 + pending_suggestions: List[Dict[str, Any]] # 待处理的AI建议 + auto_optimization_rounds: int # 自动优化迭代次数 + auto_toc_max_rounds: int # 最大自动优化次数 # 最终输出 - bid_structure: BidStructure + bid_structure: BidStructure | None # 错误处理 error: str @@ -154,6 +161,14 @@ def extract_tables_node(state: AnalysisAgentState) -> AnalysisAgentState: state["current_step"] = "extract_tables" state["progress"] = 0.25 + # 构建全文上下文供后续目录生成使用 + try: + context_builder = DocumentContextBuilder() + document_context = context_builder.build(source_file) + state["document_context"] = document_context + except Exception as exc: + logger.warning(f"构建招标文档上下文失败: {exc}") + except Exception as e: logger.error(f"表格提取失败: {e}") state["error"] = str(e) @@ -284,7 +299,9 @@ def generate_toc_with_agent_node(state: AnalysisAgentState) -> AnalysisAgentStat result = toc_agent.generate_toc_sync( technical_criteria=technical_criteria, generation_mode=generation_mode, - template_file=template_file + template_file=template_file, + document_context=state.get("document_context"), + auto_mode=state.get("auto_mode") ) # 检查是否有错误 @@ -475,6 +492,12 @@ class AnalysisAgent(BaseAgentFactory): deviation_items=[], preliminary_chapters=[], structure_review={}, + document_context=None, + auto_mode=self.settings.auto_toc_mode, + user_feedback="", + pending_suggestions=[], + auto_optimization_rounds=0, + auto_toc_max_rounds=self.settings.auto_toc_max_rounds, bid_structure=None, error="", warnings=[] @@ -482,7 +505,10 @@ class AnalysisAgent(BaseAgentFactory): try: # 执行LangGraph工作流 - final_state = await self.graph.ainvoke(initial_state) + final_state = await self.graph.ainvoke( + initial_state, + config={"recursion_limit": self.settings.langgraph_recursion_limit} + ) # 构建执行结果 if final_state.get("error"): diff --git a/src/bidmaster/agents/base.py b/src/bidmaster/agents/base.py index cab7124..a138a14 100644 --- a/src/bidmaster/agents/base.py +++ b/src/bidmaster/agents/base.py @@ -8,6 +8,7 @@ from typing import List, Dict, Any, Optional, Callable, Tuple from langgraph.graph import StateGraph, END from ..nodes.base import BaseNode, NodeContext +from ..config import get_settings logger = logging.getLogger(__name__) @@ -227,6 +228,8 @@ class BaseAgent: """ self.builder = builder self.graph = builder.build() + self.settings = get_settings() + self.recursion_limit = getattr(self.settings, "langgraph_recursion_limit", 120) async def execute(self, initial_state: Dict[str, Any]) -> Dict[str, Any]: """异步执行Agent @@ -240,7 +243,10 @@ class BaseAgent: logger.info(f"开始执行Agent,初始状态keys: {list(initial_state.keys())}") try: - final_state = await self.graph.ainvoke(initial_state) + final_state = await self.graph.ainvoke( + initial_state, + config={"recursion_limit": self.recursion_limit} + ) logger.info("Agent执行完成") return final_state except Exception as e: diff --git a/src/bidmaster/agents/builders/toc_builder.py b/src/bidmaster/agents/builders/toc_builder.py index ec03aef..5d8da5d 100644 --- a/src/bidmaster/agents/builders/toc_builder.py +++ b/src/bidmaster/agents/builders/toc_builder.py @@ -145,7 +145,9 @@ class TocAgent(BaseAgent, BaseAgentFactory): async def generate_toc(self, technical_criteria: list, generation_mode: Optional[str] = None, - template_file: Optional[str] = None) -> Dict[str, Any]: + template_file: Optional[str] = None, + document_context: Optional[Any] = None, + auto_mode: Optional[bool] = None) -> Dict[str, Any]: """生成目录结构 Args: @@ -165,7 +167,10 @@ class TocAgent(BaseAgent, BaseAgentFactory): "structure_review": {}, "final_chapters": [], "error": "", - "warnings": [] + "warnings": [], + "document_context": document_context, + "auto_mode": True if auto_mode is None else auto_mode, + "user_feedback": "" } # 如果提供了预设参数,添加到状态中 @@ -182,7 +187,9 @@ class TocAgent(BaseAgent, BaseAgentFactory): def generate_toc_sync(self, technical_criteria: list, generation_mode: Optional[str] = None, - template_file: Optional[str] = None) -> Dict[str, Any]: + template_file: Optional[str] = None, + document_context: Optional[Any] = None, + auto_mode: Optional[bool] = None) -> Dict[str, Any]: """同步生成目录结构 Args: @@ -195,5 +202,5 @@ class TocAgent(BaseAgent, BaseAgentFactory): """ import asyncio return asyncio.run( - self.generate_toc(technical_criteria, generation_mode, template_file) + self.generate_toc(technical_criteria, generation_mode, template_file, document_context, auto_mode) ) \ No newline at end of file diff --git a/src/bidmaster/config/settings.py b/src/bidmaster/config/settings.py index 2038cb2..70930e4 100644 --- a/src/bidmaster/config/settings.py +++ b/src/bidmaster/config/settings.py @@ -52,6 +52,11 @@ class Settings(BaseSettings): embedding_model: str = Field(default="text-embedding-3-small", description="嵌入模型") chunk_size: int = Field(default=1000, description="文档块大小") chunk_overlap: int = Field(default=200, description="块重叠大小") + document_context_chunk_size: int = Field(default=800, description="目录生成上下文分块大小") + document_context_overlap: int = Field(default=150, description="目录生成上下文块重叠") + document_context_top_k: int = Field(default=3, description="上下文检索结果数量") + langgraph_recursion_limit: int = Field(default=120, description="LangGraph执行的递归上限") + auto_toc_max_rounds: int = Field(default=2, description="自动目录优化最大迭代次数") # 性能配置 max_workers: int = Field(default=4, description="最大工作线程数") @@ -59,6 +64,7 @@ class Settings(BaseSettings): # 交互配置 interaction_timeout: int = Field(default=60, description="用户交互超时时间(秒)") + auto_toc_mode: bool = Field(default=True, description="目录生成是否默认自动处理反馈") # 日志配置 log_level: str = Field(default="INFO", description="日志级别") diff --git a/src/bidmaster/nodes/toc/generate_sub_chapters.py b/src/bidmaster/nodes/toc/generate_sub_chapters.py index 2ce2724..8a25856 100644 --- a/src/bidmaster/nodes/toc/generate_sub_chapters.py +++ b/src/bidmaster/nodes/toc/generate_sub_chapters.py @@ -4,7 +4,7 @@ """ import logging -from typing import Dict, List, Any +from typing import Dict, List, Any, Optional from ..base import BaseNode, NodeContext from ...tools.parser import ScoringCriteria, DocumentChapter @@ -12,6 +12,7 @@ from .category_manager import CategoryManager from .factories import ChapterFactory from .llm_helper import LLMHelper from .base_mixins import TocNodeBase +from ...utils.document_context import DocumentContextSearcher logger = logging.getLogger(__name__) @@ -40,11 +41,23 @@ class GenerateSubChaptersNode(BaseNode, TocNodeBase): preliminary_chapters = state["preliminary_chapters"] technical_criteria = state["technical_criteria"] + document_context = state.get("document_context") + + context_searcher: Optional[DocumentContextSearcher] = None + if document_context: + try: + context_searcher = DocumentContextSearcher(document_context) + except Exception as exc: + self.log_step_info("context_searcher", f"上下文检索器初始化失败: {exc}") # 为每个章节生成子标题 enhanced_chapters = [] for chapter in preliminary_chapters: - enhanced_chapter = self._enhance_chapter_with_subs(chapter, technical_criteria) + enhanced_chapter = self._enhance_chapter_with_subs( + chapter, + technical_criteria, + context_searcher + ) enhanced_chapters.append(enhanced_chapter) self.log_step_info("generate_sub_chapters", f"完成{len(enhanced_chapters)}个章节的子标题生成") @@ -53,7 +66,8 @@ class GenerateSubChaptersNode(BaseNode, TocNodeBase): def _enhance_chapter_with_subs(self, chapter: DocumentChapter, - technical_criteria: List[ScoringCriteria]) -> DocumentChapter: + technical_criteria: List[ScoringCriteria], + context_searcher: Optional[DocumentContextSearcher] = None) -> DocumentChapter: """为章节增强子标题 Args: @@ -71,8 +85,14 @@ class GenerateSubChaptersNode(BaseNode, TocNodeBase): self.log_step_info("enhance_chapter", warning_msg) return chapter + context_snippets = self._collect_context_snippets(chapter, corresponding_criteria, context_searcher) + # 使用AI生成子标题 - sub_chapters_data = LLMHelper.generate_sub_chapters_ai(corresponding_criteria, chapter) + sub_chapters_data = LLMHelper.generate_sub_chapters_ai( + corresponding_criteria, + chapter, + context_snippets=context_snippets + ) if sub_chapters_data: chapter.children = ChapterFactory.create_chapters_from_ai_response(chapter, sub_chapters_data) @@ -82,4 +102,34 @@ class GenerateSubChaptersNode(BaseNode, TocNodeBase): self.log_step_info("enhance_chapter", f"章节 {chapter.title} AI生成子标题失败") raise RuntimeError(error_msg) - return chapter \ No newline at end of file + return chapter + + def _collect_context_snippets(self, + chapter: DocumentChapter, + criteria: List[ScoringCriteria], + context_searcher: Optional[DocumentContextSearcher]) -> Optional[List[str]]: + if not context_searcher or not criteria: + return None + + query_parts = [chapter.title] + for item in criteria: + query_parts.append(item.item_name) + if item.description: + query_parts.append(item.description) + + query = ";".join(part for part in query_parts if part) + if not query: + return None + + matches = context_searcher.search(query) + if not matches: + return None + + snippets: List[str] = [] + for match in matches: + snippet_text = match.text.strip().replace("\n", " ") + if not snippet_text: + continue + snippets.append(f"{match.section}: {snippet_text[:200]}") + + return snippets or None \ No newline at end of file diff --git a/src/bidmaster/nodes/toc/llm_helper.py b/src/bidmaster/nodes/toc/llm_helper.py index cb3ca8b..c3be050 100644 --- a/src/bidmaster/nodes/toc/llm_helper.py +++ b/src/bidmaster/nodes/toc/llm_helper.py @@ -80,12 +80,14 @@ class LLMHelper: @staticmethod def generate_sub_chapters_ai(criteria_list: List[ScoringCriteria], - parent_chapter: DocumentChapter) -> Optional[List[Dict[str, Any]]]: + parent_chapter: DocumentChapter, + context_snippets: Optional[List[str]] = None) -> Optional[List[Dict[str, Any]]]: """AI生成子章节数据 Args: criteria_list: 对应的评分项列表 parent_chapter: 父章节 + context_snippets: 招标文档的相关片段 Returns: 子章节数据列表,失败时返回None @@ -97,10 +99,13 @@ class LLMHelper: # 从配置获取提示词 prompt_manager = get_prompt_manager() + context_text = "\n".join(context_snippets) if context_snippets else "(暂无上下文摘录)" + prompt = prompt_manager.get_toc_prompt( "generate_sub_chapters", parent_title=parent_chapter.title, - criteria_info=chr(10).join(criteria_info) + criteria_info=chr(10).join(criteria_info), + context_snippets=context_text ) try: @@ -117,12 +122,14 @@ class LLMHelper: @staticmethod def review_structure_ai(technical_criteria: List[ScoringCriteria], - preliminary_chapters: List[DocumentChapter]) -> Dict[str, Any]: + preliminary_chapters: List[DocumentChapter], + document_themes: Optional[str] = None) -> Dict[str, Any]: """AI审查目录结构 Args: technical_criteria: 技术评分项 preliminary_chapters: 初步生成的章节 + document_themes: 招标文档主题线索 Returns: 审查结果字典 @@ -134,10 +141,13 @@ class LLMHelper: # 从配置获取提示词 prompt_manager = get_prompt_manager() + themes_text = document_themes or "(暂无主题线索)" + review_prompt = prompt_manager.get_toc_prompt( "review_structure", criteria_summary=criteria_summary, - chapters_summary=chapters_summary + chapters_summary=chapters_summary, + document_themes=themes_text ) try: diff --git a/src/bidmaster/nodes/toc/optimize_with_feedback.py b/src/bidmaster/nodes/toc/optimize_with_feedback.py index 86b9034..bcd649e 100644 --- a/src/bidmaster/nodes/toc/optimize_with_feedback.py +++ b/src/bidmaster/nodes/toc/optimize_with_feedback.py @@ -58,18 +58,21 @@ class OptimizeWithFeedbackNode(BaseNode, TocNodeBase): return self._update_state(state, adjusted_chapters=optimized_chapters, user_feedback="", # 清空反馈 - needs_optimization=False) + needs_optimization=False, + pending_suggestions=[]) else: logger.warning("AI优化结果未按用户要求修改,保持原有结构") return self._update_state(state, user_feedback="", - needs_optimization=False) + needs_optimization=False, + pending_suggestions=[]) else: # 优化失败,保持原有结构 logger.warning("AI优化失败,保持原有结构") return self._update_state(state, user_feedback="", - needs_optimization=False) + needs_optimization=False, + pending_suggestions=[]) def _optimize_with_ai(self, chapters: List[DocumentChapter], feedback: str) -> Optional[List[DocumentChapter]]: """使用AI优化目录结构 diff --git a/src/bidmaster/nodes/toc/review_structure.py b/src/bidmaster/nodes/toc/review_structure.py index fa095f6..1606303 100644 --- a/src/bidmaster/nodes/toc/review_structure.py +++ b/src/bidmaster/nodes/toc/review_structure.py @@ -10,6 +10,7 @@ from ..base import BaseNode, NodeContext from ...tools.parser import ScoringCriteria, DocumentChapter from .llm_helper import LLMHelper from .base_mixins import TocNodeBase +from ...utils.document_context import DocumentContextSearcher logger = logging.getLogger(__name__) @@ -39,8 +40,21 @@ class ReviewStructureNode(BaseNode, TocNodeBase): technical_criteria = state["technical_criteria"] preliminary_chapters = state["preliminary_chapters"] + document_context = state.get("document_context") + document_themes = None + if document_context: + try: + theme_searcher = DocumentContextSearcher(document_context) + document_themes = theme_searcher.summarize_themes() + except Exception as exc: + self.log_step_info("review_structure", f"主题线索生成失败: {exc}") + # 执行AI审查 - review_result = LLMHelper.review_structure_ai(technical_criteria, preliminary_chapters) + review_result = LLMHelper.review_structure_ai( + technical_criteria, + preliminary_chapters, + document_themes=document_themes + ) # 根据审查结果添加警告 if review_result.get("suggestions"): @@ -51,4 +65,8 @@ class ReviewStructureNode(BaseNode, TocNodeBase): optimization_score = review_result.get('optimization_score', 'N/A') self.log_step_info("review_structure", f"AI审查完成,优化评分: {optimization_score}") - return self._update_state(state, structure_review=review_result) \ No newline at end of file + return self._update_state( + state, + structure_review=review_result, + pending_suggestions=review_result.get("suggestions", []) + ) \ No newline at end of file diff --git a/src/bidmaster/nodes/toc/user_feedback.py b/src/bidmaster/nodes/toc/user_feedback.py index bf66923..3dd09bc 100644 --- a/src/bidmaster/nodes/toc/user_feedback.py +++ b/src/bidmaster/nodes/toc/user_feedback.py @@ -37,6 +37,10 @@ class UserFeedbackNode(BaseNode, TocNodeBase): raise ValueError("缺少最终章节数据") final_chapters = state.get("final_chapters", []) + auto_mode = state.get("auto_mode", False) + + if auto_mode: + return self._handle_auto_feedback(state, final_chapters) # 获取交互处理器 interaction_handler = state.get("interaction_handler") @@ -99,6 +103,53 @@ class UserFeedbackNode(BaseNode, TocNodeBase): user_feedback=feedback, needs_optimization=True) + def _handle_auto_feedback(self, + state: Dict[str, Any], + final_chapters: List[DocumentChapter]) -> Dict[str, Any]: + review = state.get("structure_review", {}) or {} + suggestions = state.get("pending_suggestions") or review.get("suggestions", []) or [] + + max_rounds = state.get("auto_toc_max_rounds", 1) + rounds = state.get("auto_optimization_rounds", 0) + + if rounds >= max_rounds: + logger.info("自动模式: 达到最大优化次数,自动结束") + return self._update_state(state, + should_continue=False, + user_feedback="", + needs_optimization=False, + pending_suggestions=[]) + + prioritized = [ + item for item in suggestions + if item.get("priority", "low").lower() in {"high", "medium"} + ] + + if not prioritized: + logger.info("自动模式: 无需额外优化,直接结束") + return self._update_state(state, + should_continue=False, + user_feedback="", + needs_optimization=False, + pending_suggestions=[]) + + feedback_lines = [] + for idx, item in enumerate(prioritized[:5], 1): + description = item.get("description", "补充完善章节内容") + action_type = item.get("type", "adjust") + priority = item.get("priority", "medium") + feedback_lines.append(f"{idx}. [{priority}] {action_type}: {description}") + + feedback_text = "\n".join(feedback_lines) + logger.info("自动模式: 根据审查建议触发目录优化") + + return self._update_state(state, + should_continue=True, + user_feedback=feedback_text, + needs_optimization=True, + pending_suggestions=[], + auto_optimization_rounds=rounds + 1) + def _format_chapters_for_display(self, chapters: List[DocumentChapter], indent_level: int = 0) -> str: """格式化章节用于显示 diff --git a/src/bidmaster/utils/__init__.py b/src/bidmaster/utils/__init__.py index f29ad70..01be6d4 100644 --- a/src/bidmaster/utils/__init__.py +++ b/src/bidmaster/utils/__init__.py @@ -1,5 +1,15 @@ """工具模块""" from .timeout_input import timeout_prompt +from .document_context import ( + DocumentContext, + DocumentContextBuilder, + DocumentContextSearcher, +) -__all__ = ["timeout_prompt"] +__all__ = [ + "timeout_prompt", + "DocumentContext", + "DocumentContextBuilder", + "DocumentContextSearcher", +] diff --git a/src/bidmaster/utils/document_context.py b/src/bidmaster/utils/document_context.py new file mode 100644 index 0000000..c028819 --- /dev/null +++ b/src/bidmaster/utils/document_context.py @@ -0,0 +1,284 @@ +"""招标文档上下文构建与检索工具。 + +提供基于招标文件原文的分块、嵌入和相似度搜索能力, +用于目录生成阶段的智能参考。 +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Iterable, List, Optional + +import numpy as np +from docx import Document +from docx.oxml.table import CT_Tbl +from docx.oxml.text.paragraph import CT_P +from docx.table import Table +from docx.text.paragraph import Paragraph +from langchain.text_splitter import RecursiveCharacterTextSplitter +from chromadb.utils import embedding_functions + +from ..config import get_settings +from ..tools.parser import BidParser + +logger = logging.getLogger(__name__) + + +@dataclass +class DocumentContextChunk: + """单个上下文块的数据结构。""" + + index: int + text: str + section: str + source_type: str + metadata: Dict[str, Any] + embedding: List[float] + embedding_norm: float + + +@dataclass +class DocumentContext: + """完整招标文档的上下文集合。""" + + source_file: str + embedding_model: str + chunks: List[DocumentContextChunk] = field(default_factory=list) + + def is_empty(self) -> bool: + return len(self.chunks) == 0 + + +@dataclass +class DocumentContextMatch: + """检索结果结构。""" + + text: str + section: str + score: float + source_type: str + metadata: Dict[str, Any] + + +class DocumentContextBuilder: + """构建招标文档上下文。""" + + def __init__(self, + settings=None, + chunk_size: Optional[int] = None, + chunk_overlap: Optional[int] = None, + embedding_fn: Optional[Callable[[Iterable[str]], List[List[float]]]] = None): + self.settings = settings or get_settings() + self.chunk_size = chunk_size or getattr(self.settings, "document_context_chunk_size", 800) + self.chunk_overlap = chunk_overlap or getattr(self.settings, "document_context_overlap", 200) + self.embedding_function = embedding_fn or self._create_embedding_function() + self.text_splitter = RecursiveCharacterTextSplitter( + chunk_size=self.chunk_size, + chunk_overlap=self.chunk_overlap, + length_function=len, + ) + self.parser = BidParser() + + def build(self, file_path: str) -> DocumentContext: + """构建指定Word文件的上下文。""" + try: + document = Document(file_path) + except Exception as exc: + logger.error(f"加载Word文档失败: {exc}") + return DocumentContext(source_file=file_path, embedding_model=self.settings.embedding_model) + + segments = self._collect_segments(document) + if not segments: + logger.warning("未从招标文件提取到可用文本片段") + return DocumentContext(source_file=file_path, embedding_model=self.settings.embedding_model) + + raw_chunks: List[Dict[str, Any]] = [] + for segment in segments: + pieces = self.text_splitter.split_text(segment["text"]) + for piece in pieces: + cleaned_piece = piece.strip() + if not cleaned_piece: + continue + raw_chunks.append({ + "text": cleaned_piece, + "section": segment["section"], + "source_type": segment["type"], + "metadata": { + "source_type": segment["type"], + "source_section": segment["section"], + } + }) + + if not raw_chunks: + logger.warning("文本分块后为空,跳过上下文构建") + return DocumentContext(source_file=file_path, embedding_model=self.settings.embedding_model) + + embeddings = self._embed_texts([chunk["text"] for chunk in raw_chunks]) + chunks: List[DocumentContextChunk] = [] + for index, (chunk_info, vector) in enumerate(zip(raw_chunks, embeddings)): + vector_list = [float(v) for v in vector] + norm = float(np.linalg.norm(vector_list)) or 1e-8 + chunks.append(DocumentContextChunk( + index=index, + text=chunk_info["text"], + section=chunk_info["section"], + source_type=chunk_info["source_type"], + metadata=chunk_info["metadata"], + embedding=vector_list, + embedding_norm=norm + )) + + logger.info(f"上下文构建完成: {len(chunks)} 个片段") + return DocumentContext( + source_file=file_path, + embedding_model=self.settings.embedding_model, + chunks=chunks + ) + + def _collect_segments(self, document: Document) -> List[Dict[str, Any]]: + segments: List[Dict[str, Any]] = [] + current_section = "全文" + + for block in self._iter_block_items(document): + if isinstance(block, Paragraph): + text = block.text.strip() + if not text: + continue + + if self._is_heading(block): + current_section = text + segments.append({"text": text, "section": current_section, "type": "heading"}) + else: + segments.append({"text": text, "section": current_section, "type": "paragraph"}) + + elif isinstance(block, Table): + table_text = self.parser.extract_table_text(block) + if table_text: + segments.append({"text": table_text, "section": current_section, "type": "table"}) + + return segments + + def _iter_block_items(self, document: Document): + body = document.element.body + for child in body.iterchildren(): + if isinstance(child, CT_P): + yield Paragraph(child, document) + elif isinstance(child, CT_Tbl): + yield Table(child, document) + + def _is_heading(self, paragraph: Paragraph) -> bool: + style_name = getattr(paragraph.style, "name", "") or "" + return style_name.startswith("Heading") or "标题" in style_name + + def _embed_texts(self, texts: List[str]) -> List[List[float]]: + try: + vectors = self.embedding_function(texts) + return [list(vec) for vec in vectors] + except Exception as exc: + logger.error(f"计算上下文嵌入失败: {exc}") + # 退化为简单的稀疏向量,避免完全失败 + return [[float(len(text))] for text in texts] + + def _create_embedding_function(self): + model_name = self.settings.embedding_model + if model_name.startswith("text-embedding-"): + return embedding_functions.OpenAIEmbeddingFunction( + api_key=self.settings.api_key, + model_name=model_name + ) + return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name) + + +class DocumentContextSearcher: + """基于向量相似度的上下文检索器。""" + + def __init__(self, + context: Optional[DocumentContext], + settings=None, + embedding_fn: Optional[Callable[[Iterable[str]], List[List[float]]]] = None, + top_k: Optional[int] = None): + self.context = context or DocumentContext("", "", []) + self.settings = settings or get_settings() + self.top_k = top_k or getattr(self.settings, "document_context_top_k", 3) + self.embedding_function = embedding_fn or self._create_embedding_function() + + self._chunk_matrix = None + self._normalized_matrix = None + self._prepare_matrix() + + def _prepare_matrix(self) -> None: + if self.context.is_empty(): + return + + embeddings = np.array([chunk.embedding for chunk in self.context.chunks], dtype=np.float32) + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + norms[norms == 0] = 1e-8 + self._chunk_matrix = embeddings + self._normalized_matrix = embeddings / norms + + def search(self, query: str, top_k: Optional[int] = None) -> List[DocumentContextMatch]: + if not query or self._normalized_matrix is None: + return [] + + k = top_k or self.top_k + query_vector = self._embed_query(query) + if query_vector is None: + return [] + + query_norm = np.linalg.norm(query_vector) + if query_norm == 0: + return [] + normalized_query = query_vector / query_norm + similarities = self._normalized_matrix @ normalized_query + + top_indices = similarities.argsort()[-k:][::-1] + matches: List[DocumentContextMatch] = [] + for idx in top_indices: + score = float(similarities[idx]) + if score <= 0: + continue + chunk = self.context.chunks[int(idx)] + matches.append(DocumentContextMatch( + text=chunk.text, + section=chunk.section, + score=score, + source_type=chunk.source_type, + metadata=chunk.metadata + )) + + return matches + + def summarize_themes(self, max_items: int = 5) -> str: + if self.context.is_empty(): + return "(暂无主题线索)" + + section_map: Dict[str, List[str]] = {} + for chunk in self.context.chunks: + section = chunk.section or "未命名章节" + section_map.setdefault(section, []).append(chunk.text.strip()) + + sorted_sections = sorted(section_map.items(), key=lambda item: len(item[1]), reverse=True) + lines = [] + for section, excerpts in sorted_sections[:max_items]: + preview = excerpts[0].replace("\n", " ")[:120] + lines.append(f"- {section}: {preview}") + + return "\n".join(lines) if lines else "(暂无主题线索)" + + def _embed_query(self, query: str) -> Optional[np.ndarray]: + try: + vector = self.embedding_function([query])[0] + return np.array(vector, dtype=np.float32) + except Exception as exc: + logger.error(f"上下文检索计算查询嵌入失败: {exc}") + return None + + def _create_embedding_function(self): + model_name = getattr(self.context, "embedding_model", None) or self.settings.embedding_model + if model_name.startswith("text-embedding-"): + return embedding_functions.OpenAIEmbeddingFunction( + api_key=self.settings.api_key, + model_name=model_name + ) + return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name) diff --git a/tests/unit/test_document_context.py b/tests/unit/test_document_context.py new file mode 100644 index 0000000..50e03ec --- /dev/null +++ b/tests/unit/test_document_context.py @@ -0,0 +1,70 @@ +from pathlib import Path + +from docx import Document + +from bidmaster.utils.document_context import ( + DocumentContextBuilder, + DocumentContextSearcher, +) + + +def _create_sample_doc(doc_path: Path) -> None: + doc = Document() + doc.add_heading("第一章 项目总体概述", level=1) + doc.add_paragraph("本项目聚焦城市智慧照明系统建设,强调云边协同与多维感知能力。") + doc.add_heading("第二章 建设目标", level=1) + doc.add_paragraph("目标包括统一管控平台、智能终端、数据中台三大部分。") + + table = doc.add_table(rows=2, cols=2) + table.rows[0].cells[0].text = "指标" + table.rows[0].cells[1].text = "要求" + table.rows[1].cells[0].text = "系统稳定性" + table.rows[1].cells[1].text = "7x24小时无故障运行" + + doc.save(doc_path) + + +def _dummy_embedding(texts): + return [[float(len(text))] for text in texts] + + +def test_document_context_builder_creates_chunks(tmp_path): + doc_path = tmp_path / "context.docx" + _create_sample_doc(doc_path) + + builder = DocumentContextBuilder( + chunk_size=120, + chunk_overlap=10, + embedding_fn=_dummy_embedding, + ) + + context = builder.build(str(doc_path)) + + assert not context.is_empty() + assert all(chunk.embedding for chunk in context.chunks) + assert any("项目总体概述" in chunk.section for chunk in context.chunks) + + +def test_document_context_searcher_returns_matches(tmp_path): + doc_path = tmp_path / "search.docx" + _create_sample_doc(doc_path) + + builder = DocumentContextBuilder( + chunk_size=80, + chunk_overlap=10, + embedding_fn=_dummy_embedding, + ) + context = builder.build(str(doc_path)) + + searcher = DocumentContextSearcher( + context, + embedding_fn=_dummy_embedding, + top_k=2, + ) + + matches = searcher.search("智慧照明平台") + assert matches + assert matches[0].score > 0 + + themes = searcher.summarize_themes() + assert "第一章" in themes diff --git a/tests/unit/test_toc_generation.py b/tests/unit/test_toc_generation.py new file mode 100644 index 0000000..ba9af99 --- /dev/null +++ b/tests/unit/test_toc_generation.py @@ -0,0 +1,140 @@ +from bidmaster.nodes.base import NodeContext +from bidmaster.nodes.toc import generate_sub_chapters as gen_module +from bidmaster.nodes.toc.generate_sub_chapters import GenerateSubChaptersNode +from bidmaster.nodes.toc.user_feedback import UserFeedbackNode +from bidmaster.tools.parser import DocumentChapter, ScoringCriteria, TechnicalCategory +from bidmaster.utils.document_context import DocumentContext, DocumentContextMatch + + +def test_generate_sub_chapters_uses_context(monkeypatch): + captured = {} + + def fake_generate(criteria_list, parent_chapter, context_snippets=None): + captured["context"] = context_snippets + return [{"title": "示例小节", "level": 2, "score": 5, "children": []}] + + class DummySearcher: + def __init__(self, *args, **kwargs): + pass + + def search(self, query, top_k=None): + return [ + DocumentContextMatch( + text="系统采用云边协同架构,强调稳定性。", + section="项目概述", + score=0.95, + source_type="paragraph", + metadata={}, + ) + ] + + monkeypatch.setattr(gen_module, "DocumentContextSearcher", DummySearcher) + monkeypatch.setattr(gen_module.LLMHelper, "generate_sub_chapters_ai", staticmethod(fake_generate)) + + chapter = DocumentChapter( + id="chapter_01_technical_solution", + title="技术方案", + level=1, + score=30, + ) + criteria = [ + ScoringCriteria( + item_name="系统架构先进性", + max_score=10, + description="阐述云边端协同", + category=TechnicalCategory.TECHNICAL_SOLUTION, + chapter_id="chapter_01_technical_solution", + original_index=0, + ) + ] + + state = { + "preliminary_chapters": [chapter], + "technical_criteria": criteria, + "document_context": DocumentContext("demo.docx", "test-model", []), + } + + node = GenerateSubChaptersNode() + result = node.execute(state, NodeContext()) + + assert result["preliminary_chapters"][0].children + assert captured["context"] and "项目概述" in captured["context"][0] + + +def test_user_feedback_auto_triggers_optimization(): + chapters = [ + DocumentChapter( + id="chapter_01", + title="技术方案", + level=1, + score=30, + ) + ] + + state = { + "final_chapters": chapters, + "structure_review": { + "suggestions": [ + {"description": "补充智慧运维章节", "priority": "high", "type": "add"}, + {"description": "可选优化", "priority": "low", "type": "modify"}, + ] + }, + "auto_mode": True, + "auto_toc_max_rounds": 2, + "auto_optimization_rounds": 0, + } + + node = UserFeedbackNode() + result = node.execute(state, NodeContext()) + + assert result["needs_optimization"] is True + assert "补充" in result["user_feedback"] + assert result["auto_optimization_rounds"] == 1 + assert result["pending_suggestions"] == [] + + +def test_user_feedback_auto_skips_when_no_priority(): + chapters = [ + DocumentChapter( + id="chapter_02", + title="实施计划", + level=1, + score=20, + ) + ] + + state = { + "final_chapters": chapters, + "structure_review": {"suggestions": [{"description": "微调描述", "priority": "low", "type": "modify"}]}, + "auto_mode": True, + "auto_toc_max_rounds": 1, + } + + node = UserFeedbackNode() + result = node.execute(state, NodeContext()) + + assert result["needs_optimization"] is False + assert result["user_feedback"] == "" + + +def test_user_feedback_auto_respects_round_limit(): + chapters = [ + DocumentChapter(id="chapter_03", title="服务方案", level=1, score=10) + ] + + state = { + "final_chapters": chapters, + "structure_review": { + "suggestions": [{"description": "增加服务考核", "priority": "high", "type": "add"}] + }, + "auto_mode": True, + "auto_toc_max_rounds": 1, + "auto_optimization_rounds": 1, + "pending_suggestions": [{"description": "增加服务考核", "priority": "high", "type": "add"}], + } + + node = UserFeedbackNode() + result = node.execute(state, NodeContext()) + + assert result["needs_optimization"] is False + assert result["pending_suggestions"] == []