feat: add context-aware toc automation

This commit is contained in:
sladro 2025-11-18 17:54:44 +08:00
parent 6f785c9f2c
commit e028e4fa96
14 changed files with 712 additions and 25 deletions

View File

@ -113,11 +113,14 @@ toc_prompts:
【大类别】: {parent_title}
【评分项】:
{criteria_info}
【招标上下文摘录】:
{context_snippets}
生成要求:
1. 为每个评分项生成对应的子标题名称(不要包含编号)
2. 重要评分项可添加三级子标题(不要包含编号)
3. 只返回标题文本,编号由Word自动管理
3. 充分参考上下文摘录中的专业术语和业务背景,体现定制化判断
4. 只返回标题文本,编号由Word自动管理
返回JSON格式:
{{
@ -137,11 +140,14 @@ toc_prompts:
【当前生成的章节结构】:
{chapters_summary}
【招标主题线索】:
{document_themes}
【审查要求】:
1. 是否缺少重要的标准章节?
2. 章节顺序是否合理?
3. 每个评分项是否都有对应章节?
3. 是否覆盖了主题线索中的关键内容?
4. 每个评分项是否都有对应章节?
返回JSON格式:
{{

View File

@ -18,6 +18,7 @@ from ..tools.parser import BidParser, BidStructure, ScoringCriteria, DeviationIt
from ..config import get_settings
from ..nodes.toc.base_mixins import WorkflowUtilsMixin
from .base import BaseAgentFactory
from ..utils.document_context import DocumentContext, DocumentContextBuilder
logger = logging.getLogger(__name__)
@ -60,9 +61,15 @@ class AnalysisAgentState(TypedDict):
# 章节生成过程
preliminary_chapters: List[DocumentChapter] # 初步生成的章节
structure_review: Dict[str, Any] # AI审查结果
document_context: DocumentContext | None # 招标文档上下文
auto_mode: bool # 是否启用自动目录优化
user_feedback: str # 自动/人工反馈内容
pending_suggestions: List[Dict[str, Any]] # 待处理的AI建议
auto_optimization_rounds: int # 自动优化迭代次数
auto_toc_max_rounds: int # 最大自动优化次数
# 最终输出
bid_structure: BidStructure
bid_structure: BidStructure | None
# 错误处理
error: str
@ -154,6 +161,14 @@ def extract_tables_node(state: AnalysisAgentState) -> AnalysisAgentState:
state["current_step"] = "extract_tables"
state["progress"] = 0.25
# 构建全文上下文供后续目录生成使用
try:
context_builder = DocumentContextBuilder()
document_context = context_builder.build(source_file)
state["document_context"] = document_context
except Exception as exc:
logger.warning(f"构建招标文档上下文失败: {exc}")
except Exception as e:
logger.error(f"表格提取失败: {e}")
state["error"] = str(e)
@ -284,7 +299,9 @@ def generate_toc_with_agent_node(state: AnalysisAgentState) -> AnalysisAgentStat
result = toc_agent.generate_toc_sync(
technical_criteria=technical_criteria,
generation_mode=generation_mode,
template_file=template_file
template_file=template_file,
document_context=state.get("document_context"),
auto_mode=state.get("auto_mode")
)
# 检查是否有错误
@ -475,6 +492,12 @@ class AnalysisAgent(BaseAgentFactory):
deviation_items=[],
preliminary_chapters=[],
structure_review={},
document_context=None,
auto_mode=self.settings.auto_toc_mode,
user_feedback="",
pending_suggestions=[],
auto_optimization_rounds=0,
auto_toc_max_rounds=self.settings.auto_toc_max_rounds,
bid_structure=None,
error="",
warnings=[]
@ -482,7 +505,10 @@ class AnalysisAgent(BaseAgentFactory):
try:
# 执行LangGraph工作流
final_state = await self.graph.ainvoke(initial_state)
final_state = await self.graph.ainvoke(
initial_state,
config={"recursion_limit": self.settings.langgraph_recursion_limit}
)
# 构建执行结果
if final_state.get("error"):

View File

@ -8,6 +8,7 @@ from typing import List, Dict, Any, Optional, Callable, Tuple
from langgraph.graph import StateGraph, END
from ..nodes.base import BaseNode, NodeContext
from ..config import get_settings
logger = logging.getLogger(__name__)
@ -227,6 +228,8 @@ class BaseAgent:
"""
self.builder = builder
self.graph = builder.build()
self.settings = get_settings()
self.recursion_limit = getattr(self.settings, "langgraph_recursion_limit", 120)
async def execute(self, initial_state: Dict[str, Any]) -> Dict[str, Any]:
"""异步执行Agent
@ -240,7 +243,10 @@ class BaseAgent:
logger.info(f"开始执行Agent初始状态keys: {list(initial_state.keys())}")
try:
final_state = await self.graph.ainvoke(initial_state)
final_state = await self.graph.ainvoke(
initial_state,
config={"recursion_limit": self.recursion_limit}
)
logger.info("Agent执行完成")
return final_state
except Exception as e:

View File

@ -145,7 +145,9 @@ class TocAgent(BaseAgent, BaseAgentFactory):
async def generate_toc(self,
technical_criteria: list,
generation_mode: Optional[str] = None,
template_file: Optional[str] = None) -> Dict[str, Any]:
template_file: Optional[str] = None,
document_context: Optional[Any] = None,
auto_mode: Optional[bool] = None) -> Dict[str, Any]:
"""生成目录结构
Args:
@ -165,7 +167,10 @@ class TocAgent(BaseAgent, BaseAgentFactory):
"structure_review": {},
"final_chapters": [],
"error": "",
"warnings": []
"warnings": [],
"document_context": document_context,
"auto_mode": True if auto_mode is None else auto_mode,
"user_feedback": ""
}
# 如果提供了预设参数,添加到状态中
@ -182,7 +187,9 @@ class TocAgent(BaseAgent, BaseAgentFactory):
def generate_toc_sync(self,
technical_criteria: list,
generation_mode: Optional[str] = None,
template_file: Optional[str] = None) -> Dict[str, Any]:
template_file: Optional[str] = None,
document_context: Optional[Any] = None,
auto_mode: Optional[bool] = None) -> Dict[str, Any]:
"""同步生成目录结构
Args:
@ -195,5 +202,5 @@ class TocAgent(BaseAgent, BaseAgentFactory):
"""
import asyncio
return asyncio.run(
self.generate_toc(technical_criteria, generation_mode, template_file)
self.generate_toc(technical_criteria, generation_mode, template_file, document_context, auto_mode)
)

View File

@ -52,6 +52,11 @@ class Settings(BaseSettings):
embedding_model: str = Field(default="text-embedding-3-small", description="嵌入模型")
chunk_size: int = Field(default=1000, description="文档块大小")
chunk_overlap: int = Field(default=200, description="块重叠大小")
document_context_chunk_size: int = Field(default=800, description="目录生成上下文分块大小")
document_context_overlap: int = Field(default=150, description="目录生成上下文块重叠")
document_context_top_k: int = Field(default=3, description="上下文检索结果数量")
langgraph_recursion_limit: int = Field(default=120, description="LangGraph执行的递归上限")
auto_toc_max_rounds: int = Field(default=2, description="自动目录优化最大迭代次数")
# 性能配置
max_workers: int = Field(default=4, description="最大工作线程数")
@ -59,6 +64,7 @@ class Settings(BaseSettings):
# 交互配置
interaction_timeout: int = Field(default=60, description="用户交互超时时间(秒)")
auto_toc_mode: bool = Field(default=True, description="目录生成是否默认自动处理反馈")
# 日志配置
log_level: str = Field(default="INFO", description="日志级别")

View File

@ -4,7 +4,7 @@
"""
import logging
from typing import Dict, List, Any
from typing import Dict, List, Any, Optional
from ..base import BaseNode, NodeContext
from ...tools.parser import ScoringCriteria, DocumentChapter
@ -12,6 +12,7 @@ from .category_manager import CategoryManager
from .factories import ChapterFactory
from .llm_helper import LLMHelper
from .base_mixins import TocNodeBase
from ...utils.document_context import DocumentContextSearcher
logger = logging.getLogger(__name__)
@ -40,11 +41,23 @@ class GenerateSubChaptersNode(BaseNode, TocNodeBase):
preliminary_chapters = state["preliminary_chapters"]
technical_criteria = state["technical_criteria"]
document_context = state.get("document_context")
context_searcher: Optional[DocumentContextSearcher] = None
if document_context:
try:
context_searcher = DocumentContextSearcher(document_context)
except Exception as exc:
self.log_step_info("context_searcher", f"上下文检索器初始化失败: {exc}")
# 为每个章节生成子标题
enhanced_chapters = []
for chapter in preliminary_chapters:
enhanced_chapter = self._enhance_chapter_with_subs(chapter, technical_criteria)
enhanced_chapter = self._enhance_chapter_with_subs(
chapter,
technical_criteria,
context_searcher
)
enhanced_chapters.append(enhanced_chapter)
self.log_step_info("generate_sub_chapters", f"完成{len(enhanced_chapters)}个章节的子标题生成")
@ -53,7 +66,8 @@ class GenerateSubChaptersNode(BaseNode, TocNodeBase):
def _enhance_chapter_with_subs(self,
chapter: DocumentChapter,
technical_criteria: List[ScoringCriteria]) -> DocumentChapter:
technical_criteria: List[ScoringCriteria],
context_searcher: Optional[DocumentContextSearcher] = None) -> DocumentChapter:
"""为章节增强子标题
Args:
@ -71,8 +85,14 @@ class GenerateSubChaptersNode(BaseNode, TocNodeBase):
self.log_step_info("enhance_chapter", warning_msg)
return chapter
context_snippets = self._collect_context_snippets(chapter, corresponding_criteria, context_searcher)
# 使用AI生成子标题
sub_chapters_data = LLMHelper.generate_sub_chapters_ai(corresponding_criteria, chapter)
sub_chapters_data = LLMHelper.generate_sub_chapters_ai(
corresponding_criteria,
chapter,
context_snippets=context_snippets
)
if sub_chapters_data:
chapter.children = ChapterFactory.create_chapters_from_ai_response(chapter, sub_chapters_data)
@ -82,4 +102,34 @@ class GenerateSubChaptersNode(BaseNode, TocNodeBase):
self.log_step_info("enhance_chapter", f"章节 {chapter.title} AI生成子标题失败")
raise RuntimeError(error_msg)
return chapter
return chapter
def _collect_context_snippets(self,
                              chapter: DocumentChapter,
                              criteria: List[ScoringCriteria],
                              context_searcher: Optional[DocumentContextSearcher]) -> Optional[List[str]]:
    """Look up tender-document excerpts relevant to *chapter* and its criteria.

    Args:
        chapter: Chapter whose title seeds the similarity query.
        criteria: Scoring criteria mapped to this chapter; names and
            descriptions are appended to the query.
        context_searcher: Prepared searcher, or None when no document
            context was built.

    Returns:
        Formatted snippet strings ("<section>: <text capped at 200 chars>"),
        or None when there is no searcher, no criteria, an empty query,
        or no matches.
    """
    if not context_searcher or not criteria:
        return None
    # Build one query string from the chapter title plus every criterion
    # name/description.
    query_parts = [chapter.title]
    for item in criteria:
        query_parts.append(item.item_name)
        if item.description:
            query_parts.append(item.description)
    # NOTE(review): parts are concatenated with no separator — acceptable
    # for Chinese text, but confirm it is intended for mixed-language input.
    query = "".join(part for part in query_parts if part)
    if not query:
        return None
    matches = context_searcher.search(query)
    if not matches:
        return None
    snippets: List[str] = []
    for match in matches:
        # Flatten each match to a single line before truncating.
        snippet_text = match.text.strip().replace("\n", " ")
        if not snippet_text:
            continue
        snippets.append(f"{match.section}: {snippet_text[:200]}")
    return snippets or None

View File

@ -80,12 +80,14 @@ class LLMHelper:
@staticmethod
def generate_sub_chapters_ai(criteria_list: List[ScoringCriteria],
parent_chapter: DocumentChapter) -> Optional[List[Dict[str, Any]]]:
parent_chapter: DocumentChapter,
context_snippets: Optional[List[str]] = None) -> Optional[List[Dict[str, Any]]]:
"""AI生成子章节数据
Args:
criteria_list: 对应的评分项列表
parent_chapter: 父章节
context_snippets: 招标文档的相关片段
Returns:
子章节数据列表失败时返回None
@ -97,10 +99,13 @@ class LLMHelper:
# 从配置获取提示词
prompt_manager = get_prompt_manager()
context_text = "\n".join(context_snippets) if context_snippets else "(暂无上下文摘录)"
prompt = prompt_manager.get_toc_prompt(
"generate_sub_chapters",
parent_title=parent_chapter.title,
criteria_info=chr(10).join(criteria_info)
criteria_info=chr(10).join(criteria_info),
context_snippets=context_text
)
try:
@ -117,12 +122,14 @@ class LLMHelper:
@staticmethod
def review_structure_ai(technical_criteria: List[ScoringCriteria],
preliminary_chapters: List[DocumentChapter]) -> Dict[str, Any]:
preliminary_chapters: List[DocumentChapter],
document_themes: Optional[str] = None) -> Dict[str, Any]:
"""AI审查目录结构
Args:
technical_criteria: 技术评分项
preliminary_chapters: 初步生成的章节
document_themes: 招标文档主题线索
Returns:
审查结果字典
@ -134,10 +141,13 @@ class LLMHelper:
# 从配置获取提示词
prompt_manager = get_prompt_manager()
themes_text = document_themes or "(暂无主题线索)"
review_prompt = prompt_manager.get_toc_prompt(
"review_structure",
criteria_summary=criteria_summary,
chapters_summary=chapters_summary
chapters_summary=chapters_summary,
document_themes=themes_text
)
try:

View File

@ -58,18 +58,21 @@ class OptimizeWithFeedbackNode(BaseNode, TocNodeBase):
return self._update_state(state,
adjusted_chapters=optimized_chapters,
user_feedback="", # 清空反馈
needs_optimization=False)
needs_optimization=False,
pending_suggestions=[])
else:
logger.warning("AI优化结果未按用户要求修改保持原有结构")
return self._update_state(state,
user_feedback="",
needs_optimization=False)
needs_optimization=False,
pending_suggestions=[])
else:
# 优化失败,保持原有结构
logger.warning("AI优化失败保持原有结构")
return self._update_state(state,
user_feedback="",
needs_optimization=False)
needs_optimization=False,
pending_suggestions=[])
def _optimize_with_ai(self, chapters: List[DocumentChapter], feedback: str) -> Optional[List[DocumentChapter]]:
"""使用AI优化目录结构

View File

@ -10,6 +10,7 @@ from ..base import BaseNode, NodeContext
from ...tools.parser import ScoringCriteria, DocumentChapter
from .llm_helper import LLMHelper
from .base_mixins import TocNodeBase
from ...utils.document_context import DocumentContextSearcher
logger = logging.getLogger(__name__)
@ -39,8 +40,21 @@ class ReviewStructureNode(BaseNode, TocNodeBase):
technical_criteria = state["technical_criteria"]
preliminary_chapters = state["preliminary_chapters"]
document_context = state.get("document_context")
document_themes = None
if document_context:
try:
theme_searcher = DocumentContextSearcher(document_context)
document_themes = theme_searcher.summarize_themes()
except Exception as exc:
self.log_step_info("review_structure", f"主题线索生成失败: {exc}")
# 执行AI审查
review_result = LLMHelper.review_structure_ai(technical_criteria, preliminary_chapters)
review_result = LLMHelper.review_structure_ai(
technical_criteria,
preliminary_chapters,
document_themes=document_themes
)
# 根据审查结果添加警告
if review_result.get("suggestions"):
@ -51,4 +65,8 @@ class ReviewStructureNode(BaseNode, TocNodeBase):
optimization_score = review_result.get('optimization_score', 'N/A')
self.log_step_info("review_structure", f"AI审查完成优化评分: {optimization_score}")
return self._update_state(state, structure_review=review_result)
return self._update_state(
state,
structure_review=review_result,
pending_suggestions=review_result.get("suggestions", [])
)

View File

@ -37,6 +37,10 @@ class UserFeedbackNode(BaseNode, TocNodeBase):
raise ValueError("缺少最终章节数据")
final_chapters = state.get("final_chapters", [])
auto_mode = state.get("auto_mode", False)
if auto_mode:
return self._handle_auto_feedback(state, final_chapters)
# 获取交互处理器
interaction_handler = state.get("interaction_handler")
@ -99,6 +103,53 @@ class UserFeedbackNode(BaseNode, TocNodeBase):
user_feedback=feedback,
needs_optimization=True)
def _handle_auto_feedback(self,
                          state: Dict[str, Any],
                          final_chapters: List[DocumentChapter]) -> Dict[str, Any]:
    """Drive the feedback step without human input (auto mode).

    Turns high/medium priority review suggestions into synthetic user
    feedback so the optimization node runs again, bounded by
    ``auto_toc_max_rounds``.

    Args:
        state: Current workflow state.
        final_chapters: Generated chapters (not read here; kept for
            signature symmetry with the interactive path).

    Returns:
        Updated state via ``self._update_state``.
    """
    review = state.get("structure_review", {}) or {}
    # Prefer explicitly tracked pending suggestions; fall back to the raw review.
    suggestions = state.get("pending_suggestions") or review.get("suggestions", []) or []
    max_rounds = state.get("auto_toc_max_rounds", 1)
    rounds = state.get("auto_optimization_rounds", 0)
    if rounds >= max_rounds:
        # Hard stop: never exceed the configured number of auto rounds.
        logger.info("自动模式: 达到最大优化次数,自动结束")
        return self._update_state(state,
                                  should_continue=False,
                                  user_feedback="",
                                  needs_optimization=False,
                                  pending_suggestions=[])
    # Only high/medium priority suggestions justify another optimization pass.
    prioritized = [
        item for item in suggestions
        if item.get("priority", "low").lower() in {"high", "medium"}
    ]
    if not prioritized:
        logger.info("自动模式: 无需额外优化,直接结束")
        return self._update_state(state,
                                  should_continue=False,
                                  user_feedback="",
                                  needs_optimization=False,
                                  pending_suggestions=[])
    # Render at most five suggestions in the same numbered text format a
    # human reviewer would type.
    feedback_lines = []
    for idx, item in enumerate(prioritized[:5], 1):
        description = item.get("description", "补充完善章节内容")
        action_type = item.get("type", "adjust")
        priority = item.get("priority", "medium")
        feedback_lines.append(f"{idx}. [{priority}] {action_type}: {description}")
    feedback_text = "\n".join(feedback_lines)
    logger.info("自动模式: 根据审查建议触发目录优化")
    return self._update_state(state,
                              should_continue=True,
                              user_feedback=feedback_text,
                              needs_optimization=True,
                              pending_suggestions=[],
                              auto_optimization_rounds=rounds + 1)
def _format_chapters_for_display(self, chapters: List[DocumentChapter], indent_level: int = 0) -> str:
"""格式化章节用于显示

View File

@ -1,5 +1,15 @@
"""工具模块"""
from .timeout_input import timeout_prompt
from .document_context import (
DocumentContext,
DocumentContextBuilder,
DocumentContextSearcher,
)
__all__ = ["timeout_prompt"]
__all__ = [
"timeout_prompt",
"DocumentContext",
"DocumentContextBuilder",
"DocumentContextSearcher",
]

View File

@ -0,0 +1,284 @@
"""招标文档上下文构建与检索工具。
提供基于招标文件原文的分块嵌入和相似度搜索能力
用于目录生成阶段的智能参考
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterable, List, Optional
import numpy as np
from docx import Document
from docx.oxml.table import CT_Tbl
from docx.oxml.text.paragraph import CT_P
from docx.table import Table
from docx.text.paragraph import Paragraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
from chromadb.utils import embedding_functions
from ..config import get_settings
from ..tools.parser import BidParser
logger = logging.getLogger(__name__)
@dataclass
class DocumentContextChunk:
    """One embedded text chunk of the tender document."""

    index: int  # position within DocumentContext.chunks
    text: str  # cleaned chunk text
    section: str  # heading of the section the chunk came from
    source_type: str  # "heading" | "paragraph" | "table" (set by the builder)
    metadata: Dict[str, Any]  # builder sets source_type / source_section keys
    embedding: List[float]  # embedding vector used for similarity search
    embedding_norm: float  # L2 norm of `embedding` (builder floors 0 to 1e-8)
@dataclass
class DocumentContext:
    """Embedded context for one whole tender document."""

    source_file: str  # path of the Word file the context was built from
    embedding_model: str  # embedding model name used for the chunks
    chunks: List[DocumentContextChunk] = field(default_factory=list)

    def is_empty(self) -> bool:
        """Return True when no chunks were extracted/embedded."""
        return len(self.chunks) == 0
@dataclass
class DocumentContextMatch:
    """One similarity-search hit returned by DocumentContextSearcher."""

    text: str  # matched chunk text
    section: str  # section heading of the matched chunk
    score: float  # cosine similarity (searcher only keeps scores > 0)
    source_type: str  # original segment kind of the chunk
    metadata: Dict[str, Any]  # chunk metadata, passed through unchanged
class DocumentContextBuilder:
    """Build a :class:`DocumentContext` from a tender Word document.

    Pipeline: load the .docx, walk paragraphs/tables in document order,
    split the collected text into overlapping chunks, embed every chunk,
    and package the result. All failures degrade to an empty context or
    to a trivial fallback embedding — building context never raises.
    """

    def __init__(self,
                 settings=None,
                 chunk_size: Optional[int] = None,
                 chunk_overlap: Optional[int] = None,
                 embedding_fn: Optional[Callable[[Iterable[str]], List[List[float]]]] = None):
        """Create a builder; all arguments default from global settings.

        Args:
            settings: Settings object; defaults to ``get_settings()``.
            chunk_size: Max characters per chunk.
            chunk_overlap: Overlap between adjacent chunks.
            embedding_fn: Callable mapping a batch of texts to vectors;
                defaults to an OpenAI or SentenceTransformer function
                chosen from the configured model name.
        """
        self.settings = settings or get_settings()
        self.chunk_size = chunk_size or getattr(self.settings, "document_context_chunk_size", 800)
        # NOTE(review): fallback default 200 differs from the Settings field
        # default of 150 — confirm which is intended.
        self.chunk_overlap = chunk_overlap or getattr(self.settings, "document_context_overlap", 200)
        self.embedding_function = embedding_fn or self._create_embedding_function()
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len,
        )
        self.parser = BidParser()

    def build(self, file_path: str) -> DocumentContext:
        """Build the context for the given Word file.

        Returns an empty :class:`DocumentContext` (never raises) when the
        file cannot be loaded or yields no usable text.
        """
        try:
            document = Document(file_path)
        except Exception as exc:
            logger.error(f"加载Word文档失败: {exc}")
            return DocumentContext(source_file=file_path, embedding_model=self.settings.embedding_model)
        segments = self._collect_segments(document)
        if not segments:
            logger.warning("未从招标文件提取到可用文本片段")
            return DocumentContext(source_file=file_path, embedding_model=self.settings.embedding_model)
        # Split every segment into overlapping chunks, keeping section tags.
        raw_chunks: List[Dict[str, Any]] = []
        for segment in segments:
            pieces = self.text_splitter.split_text(segment["text"])
            for piece in pieces:
                cleaned_piece = piece.strip()
                if not cleaned_piece:
                    continue
                raw_chunks.append({
                    "text": cleaned_piece,
                    "section": segment["section"],
                    "source_type": segment["type"],
                    "metadata": {
                        "source_type": segment["type"],
                        "source_section": segment["section"],
                    }
                })
        if not raw_chunks:
            logger.warning("文本分块后为空,跳过上下文构建")
            return DocumentContext(source_file=file_path, embedding_model=self.settings.embedding_model)
        embeddings = self._embed_texts([chunk["text"] for chunk in raw_chunks])
        chunks: List[DocumentContextChunk] = []
        for index, (chunk_info, vector) in enumerate(zip(raw_chunks, embeddings)):
            vector_list = [float(v) for v in vector]
            # Floor a zero norm so later cosine math never divides by zero.
            norm = float(np.linalg.norm(vector_list)) or 1e-8
            chunks.append(DocumentContextChunk(
                index=index,
                text=chunk_info["text"],
                section=chunk_info["section"],
                source_type=chunk_info["source_type"],
                metadata=chunk_info["metadata"],
                embedding=vector_list,
                embedding_norm=norm
            ))
        logger.info(f"上下文构建完成: {len(chunks)} 个片段")
        return DocumentContext(
            source_file=file_path,
            embedding_model=self.settings.embedding_model,
            chunks=chunks
        )

    def _collect_segments(self, document: Document) -> List[Dict[str, Any]]:
        """Walk the document in order, tagging each text segment with the
        most recent heading seen ("全文" before the first heading)."""
        segments: List[Dict[str, Any]] = []
        current_section = "全文"
        for block in self._iter_block_items(document):
            if isinstance(block, Paragraph):
                text = block.text.strip()
                if not text:
                    continue
                if self._is_heading(block):
                    # A heading starts a new section and is itself a segment.
                    current_section = text
                    segments.append({"text": text, "section": current_section, "type": "heading"})
                else:
                    segments.append({"text": text, "section": current_section, "type": "paragraph"})
            elif isinstance(block, Table):
                table_text = self.parser.extract_table_text(block)
                if table_text:
                    segments.append({"text": table_text, "section": current_section, "type": "table"})
        return segments

    def _iter_block_items(self, document: Document):
        """Yield paragraphs and tables in true document order by walking
        the underlying XML body (python-docx has no combined iterator)."""
        body = document.element.body
        for child in body.iterchildren():
            if isinstance(child, CT_P):
                yield Paragraph(child, document)
            elif isinstance(child, CT_Tbl):
                yield Table(child, document)

    def _is_heading(self, paragraph: Paragraph) -> bool:
        """Return True for built-in "Heading*" styles or Chinese "标题" styles."""
        style_name = getattr(paragraph.style, "name", "") or ""
        return style_name.startswith("Heading") or "标题" in style_name

    def _embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts; on failure fall back to trivial
        length-based 1-D vectors so the build still succeeds."""
        try:
            vectors = self.embedding_function(texts)
            return [list(vec) for vec in vectors]
        except Exception as exc:
            logger.error(f"计算上下文嵌入失败: {exc}")
            # Degrade to a simple sparse vector instead of failing outright.
            return [[float(len(text))] for text in texts]

    def _create_embedding_function(self):
        """Choose OpenAI embeddings for "text-embedding-*" model names,
        SentenceTransformer for everything else."""
        model_name = self.settings.embedding_model
        if model_name.startswith("text-embedding-"):
            return embedding_functions.OpenAIEmbeddingFunction(
                api_key=self.settings.api_key,
                model_name=model_name
            )
        return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)
class DocumentContextSearcher:
"""基于向量相似度的上下文检索器。"""
def __init__(self,
context: Optional[DocumentContext],
settings=None,
embedding_fn: Optional[Callable[[Iterable[str]], List[List[float]]]] = None,
top_k: Optional[int] = None):
self.context = context or DocumentContext("", "", [])
self.settings = settings or get_settings()
self.top_k = top_k or getattr(self.settings, "document_context_top_k", 3)
self.embedding_function = embedding_fn or self._create_embedding_function()
self._chunk_matrix = None
self._normalized_matrix = None
self._prepare_matrix()
def _prepare_matrix(self) -> None:
if self.context.is_empty():
return
embeddings = np.array([chunk.embedding for chunk in self.context.chunks], dtype=np.float32)
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1e-8
self._chunk_matrix = embeddings
self._normalized_matrix = embeddings / norms
def search(self, query: str, top_k: Optional[int] = None) -> List[DocumentContextMatch]:
if not query or self._normalized_matrix is None:
return []
k = top_k or self.top_k
query_vector = self._embed_query(query)
if query_vector is None:
return []
query_norm = np.linalg.norm(query_vector)
if query_norm == 0:
return []
normalized_query = query_vector / query_norm
similarities = self._normalized_matrix @ normalized_query
top_indices = similarities.argsort()[-k:][::-1]
matches: List[DocumentContextMatch] = []
for idx in top_indices:
score = float(similarities[idx])
if score <= 0:
continue
chunk = self.context.chunks[int(idx)]
matches.append(DocumentContextMatch(
text=chunk.text,
section=chunk.section,
score=score,
source_type=chunk.source_type,
metadata=chunk.metadata
))
return matches
def summarize_themes(self, max_items: int = 5) -> str:
if self.context.is_empty():
return "(暂无主题线索)"
section_map: Dict[str, List[str]] = {}
for chunk in self.context.chunks:
section = chunk.section or "未命名章节"
section_map.setdefault(section, []).append(chunk.text.strip())
sorted_sections = sorted(section_map.items(), key=lambda item: len(item[1]), reverse=True)
lines = []
for section, excerpts in sorted_sections[:max_items]:
preview = excerpts[0].replace("\n", " ")[:120]
lines.append(f"- {section}: {preview}")
return "\n".join(lines) if lines else "(暂无主题线索)"
def _embed_query(self, query: str) -> Optional[np.ndarray]:
try:
vector = self.embedding_function([query])[0]
return np.array(vector, dtype=np.float32)
except Exception as exc:
logger.error(f"上下文检索计算查询嵌入失败: {exc}")
return None
def _create_embedding_function(self):
model_name = getattr(self.context, "embedding_model", None) or self.settings.embedding_model
if model_name.startswith("text-embedding-"):
return embedding_functions.OpenAIEmbeddingFunction(
api_key=self.settings.api_key,
model_name=model_name
)
return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)

View File

@ -0,0 +1,70 @@
from pathlib import Path
from docx import Document
from bidmaster.utils.document_context import (
DocumentContextBuilder,
DocumentContextSearcher,
)
def _create_sample_doc(doc_path: Path) -> None:
    """Write a small two-chapter .docx fixture (headings, paragraphs, one table)."""
    document = Document()
    document.add_heading("第一章 项目总体概述", level=1)
    document.add_paragraph("本项目聚焦城市智慧照明系统建设,强调云边协同与多维感知能力。")
    document.add_heading("第二章 建设目标", level=1)
    document.add_paragraph("目标包括统一管控平台、智能终端、数据中台三大部分。")
    # 2x2 table: header row plus one requirement row.
    table = document.add_table(rows=2, cols=2)
    cell_values = [
        (0, 0, "指标"),
        (0, 1, "要求"),
        (1, 0, "系统稳定性"),
        (1, 1, "7x24小时无故障运行"),
    ]
    for row, col, value in cell_values:
        table.rows[row].cells[col].text = value
    document.save(doc_path)
def _dummy_embedding(texts):
return [[float(len(text))] for text in texts]
def test_document_context_builder_creates_chunks(tmp_path):
    """Builder should chunk and embed a real .docx, tagging chunks with their section heading."""
    doc_path = tmp_path / "context.docx"
    _create_sample_doc(doc_path)
    builder = DocumentContextBuilder(
        chunk_size=120,
        chunk_overlap=10,
        embedding_fn=_dummy_embedding,  # deterministic, no network access
    )
    context = builder.build(str(doc_path))
    assert not context.is_empty()
    # Every chunk must carry a non-empty embedding vector.
    assert all(chunk.embedding for chunk in context.chunks)
    # Section tags come from the Heading-1 paragraphs of the fixture.
    assert any("项目总体概述" in chunk.section for chunk in context.chunks)
def test_document_context_searcher_returns_matches(tmp_path):
    """Searcher over a freshly built context should return positive-score matches and theme lines."""
    doc_path = tmp_path / "search.docx"
    _create_sample_doc(doc_path)
    builder = DocumentContextBuilder(
        chunk_size=80,
        chunk_overlap=10,
        embedding_fn=_dummy_embedding,  # deterministic, no network access
    )
    context = builder.build(str(doc_path))
    searcher = DocumentContextSearcher(
        context,
        embedding_fn=_dummy_embedding,  # must match the builder's embedding space
        top_k=2,
    )
    matches = searcher.search("智慧照明平台")
    assert matches
    assert matches[0].score > 0
    themes = searcher.summarize_themes()
    # Theme summary lists section headings taken from the fixture document.
    assert "第一章" in themes

View File

@ -0,0 +1,140 @@
from bidmaster.nodes.base import NodeContext
from bidmaster.nodes.toc import generate_sub_chapters as gen_module
from bidmaster.nodes.toc.generate_sub_chapters import GenerateSubChaptersNode
from bidmaster.nodes.toc.user_feedback import UserFeedbackNode
from bidmaster.tools.parser import DocumentChapter, ScoringCriteria, TechnicalCategory
from bidmaster.utils.document_context import DocumentContext, DocumentContextMatch
def test_generate_sub_chapters_uses_context(monkeypatch):
    """Node should pass searcher-derived context snippets into the LLM helper call."""
    captured = {}

    # Stand-in for the LLM call: record the snippets it receives and
    # return one valid sub-chapter payload.
    def fake_generate(criteria_list, parent_chapter, context_snippets=None):
        captured["context"] = context_snippets
        return [{"title": "示例小节", "level": 2, "score": 5, "children": []}]

    # Stand-in searcher that always returns one fixed match.
    class DummySearcher:
        def __init__(self, *args, **kwargs):
            pass

        def search(self, query, top_k=None):
            return [
                DocumentContextMatch(
                    text="系统采用云边协同架构,强调稳定性。",
                    section="项目概述",
                    score=0.95,
                    source_type="paragraph",
                    metadata={},
                )
            ]

    # Patch at the module under test so the node picks up both doubles.
    monkeypatch.setattr(gen_module, "DocumentContextSearcher", DummySearcher)
    monkeypatch.setattr(gen_module.LLMHelper, "generate_sub_chapters_ai", staticmethod(fake_generate))
    chapter = DocumentChapter(
        id="chapter_01_technical_solution",
        title="技术方案",
        level=1,
        score=30,
    )
    criteria = [
        ScoringCriteria(
            item_name="系统架构先进性",
            max_score=10,
            description="阐述云边端协同",
            category=TechnicalCategory.TECHNICAL_SOLUTION,
            chapter_id="chapter_01_technical_solution",
            original_index=0,
        )
    ]
    state = {
        "preliminary_chapters": [chapter],
        "technical_criteria": criteria,
        # Non-empty context object triggers searcher construction.
        "document_context": DocumentContext("demo.docx", "test-model", []),
    }
    node = GenerateSubChaptersNode()
    result = node.execute(state, NodeContext())
    assert result["preliminary_chapters"][0].children
    # The snippet format is "<section>: <text>", so the section name must appear.
    assert captured["context"] and "项目概述" in captured["context"][0]
def test_user_feedback_auto_triggers_optimization():
    """A high-priority suggestion in auto mode must trigger one optimization round."""
    chapters = [
        DocumentChapter(
            id="chapter_01",
            title="技术方案",
            level=1,
            score=30,
        )
    ]
    state = {
        "final_chapters": chapters,
        "structure_review": {
            "suggestions": [
                # Only the high-priority entry should be turned into feedback.
                {"description": "补充智慧运维章节", "priority": "high", "type": "add"},
                {"description": "可选优化", "priority": "low", "type": "modify"},
            ]
        },
        "auto_mode": True,
        "auto_toc_max_rounds": 2,
        "auto_optimization_rounds": 0,
    }
    node = UserFeedbackNode()
    result = node.execute(state, NodeContext())
    assert result["needs_optimization"] is True
    assert "补充" in result["user_feedback"]
    # Round counter advances and processed suggestions are cleared.
    assert result["auto_optimization_rounds"] == 1
    assert result["pending_suggestions"] == []
def test_user_feedback_auto_skips_when_no_priority():
    """Auto mode must finish without another round when every suggestion is low priority."""
    only_chapter = DocumentChapter(id="chapter_02", title="实施计划", level=1, score=20)
    low_priority_suggestion = {"description": "微调描述", "priority": "low", "type": "modify"}
    state = {
        "final_chapters": [only_chapter],
        "structure_review": {"suggestions": [low_priority_suggestion]},
        "auto_mode": True,
        "auto_toc_max_rounds": 1,
    }
    result = UserFeedbackNode().execute(state, NodeContext())
    assert result["needs_optimization"] is False
    assert result["user_feedback"] == ""
def test_user_feedback_auto_respects_round_limit():
    """Once the round limit is reached, auto mode stops even with pending high-priority suggestions."""
    chapters = [
        DocumentChapter(id="chapter_03", title="服务方案", level=1, score=10)
    ]
    state = {
        "final_chapters": chapters,
        "structure_review": {
            "suggestions": [{"description": "增加服务考核", "priority": "high", "type": "add"}]
        },
        "auto_mode": True,
        "auto_toc_max_rounds": 1,
        "auto_optimization_rounds": 1,  # the single allowed round was already used
        "pending_suggestions": [{"description": "增加服务考核", "priority": "high", "type": "add"}],
    }
    node = UserFeedbackNode()
    result = node.execute(state, NodeContext())
    assert result["needs_optimization"] is False
    assert result["pending_suggestions"] == []