feat: default tender doc rag selection

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
This commit is contained in:
sladro 2025-11-19 16:24:56 +08:00
parent c1292fcacc
commit 08db5a66bb
6 changed files with 557 additions and 58 deletions

View File

@ -3,11 +3,13 @@
负责标书章节内容的自动填写协调器模式
"""
import json
import logging
from pathlib import Path
from typing import Any, Dict
from typing import Any, Dict, Optional
from .single_chapter_agent import SingleChapterAgent
from ..utils.document_context import DocumentContextBuilder, DocumentContextSearcher
logger = logging.getLogger(__name__)
@ -29,7 +31,12 @@ class ContentWriterAgent:
"""
self.interaction_handler = interaction_handler
def write_content_sync(self, word_file: str) -> Dict[str, Any]:
def write_content_sync(
self,
word_file: str,
project_dir: Optional[str] = None,
tender_file: Optional[str] = None,
) -> Dict[str, Any]:
"""同步执行内容填写
Args:
@ -54,11 +61,14 @@ class ContentWriterAgent:
init_state = {
"word_file": word_file,
"project_dir": str(project_dir) if project_dir else None,
}
state = init_node.execute(init_state, context)
chapter_queue = state["chapter_queue"]
state = self._prepare_rag_sources(state, project_dir, tender_file)
logger.info(f"{len(chapter_queue)} 个章节待处理")
for i, ch in enumerate(chapter_queue):
logger.info(f" [{i+1}] {ch['id']} - {ch['title']} (level={ch['level']})")
@ -127,4 +137,170 @@ class ContentWriterAgent:
"""创建交互模式Agent"""
from .interaction import InteractionHandler, InteractionMode
handler = InteractionHandler(mode=InteractionMode.INTERACTIVE)
return cls(handler)
return cls(handler)
def _prepare_rag_sources(
    self,
    state: Dict[str, Any],
    project_dir: Optional[str],
    tender_file: Optional[str],
) -> Dict[str, Any]:
    """Build the list of usable RAG sources, loading tender context on demand.

    Args:
        state: Workflow state; updated in place and also returned.
        project_dir: Project directory forwarded to the tender-file lookup.
        tender_file: Explicit tender document path forwarded to the lookup.

    Returns:
        The same ``state`` mapping, with ``available_rag_sources`` set and
        the ``tender_doc_*`` keys either filled in or removed.
    """
    resolved = self._resolve_tender_source(project_dir, tender_file)
    loaded: Optional[Dict[str, Any]] = (
        self._build_tender_searcher(state, resolved) if resolved else None
    )

    # Nothing usable was found automatically: give the user one chance to
    # point at a tender document by hand.
    if not loaded:
        manual_path = self._prompt_for_tender_source()
        if manual_path:
            loaded = self._build_tender_searcher(state, manual_path)

    if not loaded:
        # Drop any tender entries already present in the state.
        for stale_key in ("tender_doc_source_file", "tender_doc_searcher"):
            state.pop(stale_key, None)
    else:
        state["tender_doc_source_file"] = loaded["file_path"]
        state["tender_doc_searcher"] = loaded["searcher"]

    state["available_rag_sources"] = self._compose_available_sources(state, loaded)
    return state
def _resolve_tender_source(
    self,
    project_dir: Optional[str],
    tender_file: Optional[str],
) -> Optional[str]:
    """Locate the tender document that should feed the tender-doc RAG source.

    An explicit ``tender_file`` takes precedence. Otherwise the ``source_file``
    entry of ``<project_dir>/analysis_result.json`` is consulted.

    Args:
        project_dir: Project directory that may contain ``analysis_result.json``.
        tender_file: Explicit path to the tender document, if any.

    Returns:
        The path as a string when the file exists, otherwise ``None``. An
        unreadable or malformed ``analysis_result.json`` also yields ``None``
        rather than raising.
    """
    if tender_file:
        path = Path(tender_file)
        return str(path) if path.exists() else None

    if not project_dir:
        return None

    analysis_file = Path(project_dir) / "analysis_result.json"
    if not analysis_file.exists():
        return None

    try:
        with open(analysis_file, "r", encoding="utf-8") as fp:
            analysis_data = json.load(fp)
    except Exception:
        # Best effort: the tender context is optional, so a broken analysis
        # file must not abort the run.
        logger.warning("读取 analysis_result.json 失败,无法自动注入招标上下文")
        return None

    # json.load may legally return a list/str/number; only an object can
    # carry the ``source_file`` key.
    if not isinstance(analysis_data, dict):
        logger.warning("analysis_result.json 格式异常,无法自动注入招标上下文")
        return None

    source_file = analysis_data.get("source_file")
    if not source_file:
        return None
    if not isinstance(source_file, str):
        # Path() would raise TypeError on e.g. a number or a nested object.
        logger.warning("analysis_result.json 中的 source_file 不是有效路径: %r", source_file)
        return None

    source_path = Path(source_file)
    if not source_path.exists():
        logger.warning("analysis_result.json 中的招标文件不存在: %s", source_file)
        return None
    return str(source_path)
def _prompt_for_tender_source(self) -> Optional[str]:
    """Ask the user whether to supply a tender document and return its path.

    The interaction handler is first asked a yes/no question (default: no);
    on confirmation it is asked for a ``.docx`` path.

    Returns:
        The chosen path as a string when it exists on disk; ``None`` when no
        handler is configured, the user declines, no path is given, the path
        does not exist, or the handler raises.
    """
    handler = self.interaction_handler
    if not handler:
        return None

    # Both prompts are guarded: the tender context is optional, so a failing
    # interaction should only skip it, not abort content writing.
    try:
        should_provide = handler(
            interaction_type="confirm",
            prompt="未找到招标文件上下文是否手动选择一个招标文件供RAG使用",
            default=False,
            key="provide_tender_doc",
        )
        if not should_provide:
            return None

        path = handler(
            interaction_type="file_path",
            prompt="请输入招标文件(.docx)路径",
            default="",
            key="tender_doc_path",
            validation={"exists": True, "extensions": [".docx"]},
        )
    except Exception as exc:
        logger.warning("招标文件交互失败,跳过招标文件上下文: %s", exc)
        return None

    if not path:
        return None

    # Re-check existence here instead of relying on the ``validation`` hint
    # passed to the handler.
    file_path = Path(path)
    return str(file_path) if file_path.exists() else None
def _build_tender_searcher(
    self,
    state: Dict[str, Any],
    tender_path: str,
) -> Optional[Dict[str, Any]]:
    """Build a searcher over the tender document at ``tender_path``.

    Args:
        state: Workflow state; not read or modified by this method.
        tender_path: Path handed to ``DocumentContextBuilder.build``.

    Returns:
        A dict with ``file_path``, ``chunk_count`` and ``searcher`` keys, or
        ``None`` when the built context is empty or any step raises.
    """
    try:
        context = DocumentContextBuilder().build(tender_path)
        if context.is_empty():
            logger.warning("招标文件未生成有效上下文: %s", tender_path)
            return None

        tender_searcher = DocumentContextSearcher(context)
        display_name = Path(tender_path).name
        logger.info(
            "已载入招标文件上下文: %s (%s段)",
            display_name,
            len(context.chunks),
        )
        return dict(
            file_path=tender_path,
            chunk_count=len(context.chunks),
            searcher=tender_searcher,
        )
    except Exception as exc:
        # Any failure only disables the tender source; it is reported as a
        # warning and signalled to the caller with None.
        logger.warning("构建招标文件上下文失败: %s", exc)
        return None
def _compose_available_sources(
    self,
    state: Dict[str, Any],
    tender_info: Optional[Dict[str, Any]],
) -> Dict[str, Dict[str, Any]]:
    """Describe the ``tender_doc`` and ``global_kb`` RAG sources.

    Args:
        state: Workflow state; only its ``rag_tool`` entry is read.
        tender_info: Result of building the tender searcher, or a falsy value
            when no tender document is loaded.

    Returns:
        A mapping with the keys ``tender_doc`` and ``global_kb`` (in that
        order), each holding ``id``, ``label``, ``available`` and ``default``.
    """
    if tender_info:
        tender_name = Path(tender_info["file_path"]).name
        tender_chunks = tender_info.get("chunk_count", 0)
        tender_entry: Dict[str, Any] = {
            "id": "tender_doc",
            "label": f"招标文件({tender_name}{tender_chunks}段)",
            "available": True,
            "default": True,
        }
    else:
        tender_entry = {
            "id": "tender_doc",
            "label": "招标文件(未配置)",
            "available": False,
            "default": False,
        }

    kb_label, kb_ready = "知识库(未初始化)", False
    kb_tool = state.get("rag_tool")
    if kb_tool:
        try:
            kb_stats = kb_tool.get_stats()
            file_total = kb_stats.get("total_files", 0)
            chunk_total = kb_stats.get("total_chunks", 0)
            kb_label = f"知识库({file_total}个文档,{chunk_total}段)"
            kb_ready = True
        except Exception:
            # A failing stats call leaves the knowledge base marked unavailable.
            kb_label = "知识库(统计失败)"

    kb_entry: Dict[str, Any] = {
        "id": "global_kb",
        "label": kb_label,
        "available": kb_ready,
        # The knowledge base is the default only when no tender doc is loaded.
        "default": not tender_info,
    }
    return {"tender_doc": tender_entry, "global_kb": kb_entry}

View File

@ -36,7 +36,13 @@ def write():
default="interactive",
help="交互模式interactive(交互)|silent(静默)",
)
def start(project_dir: str, mode: str):
@click.option(
"--tender-file",
"-t",
type=click.Path(exists=True, dir_okay=False),
help="用于RAG的招标文件(.docx)",
)
def start(project_dir: str, mode: str, tender_file: str | None = None):
"""开始填写标书内容
PROJECT_DIR: 项目目录路径默认当前目录
@ -58,7 +64,8 @@ def start(project_dir: str, mode: str):
console.print(f"❌ 项目目录不存在: {project_dir}", style="red")
return
console.print(f"📁 项目目录: {project_path.absolute()}", style="dim")
project_abs_path = project_path.resolve()
console.print(f"📁 项目目录: {project_abs_path}", style="dim")
# 自动查找Word文件
word_file = _find_word_file(project_path)
@ -68,6 +75,9 @@ def start(project_dir: str, mode: str):
console.print(f"📄 找到Word文档: {Path(word_file).name}", style="green")
console.print("📋 将直接从Word文档提取章节结构", style="dim")
# 规范化可选招标文件路径
tender_file_path = str(Path(tender_file).resolve()) if tender_file else None
# 创建Agent
if mode == "silent":
agent = ContentWriterAgent.create_silent()
@ -91,11 +101,19 @@ def start(project_dir: str, mode: str):
console=console,
) as progress:
task = progress.add_task("正在处理章节...", total=None)
result = agent.write_content_sync(word_file=word_file)
result = agent.write_content_sync(
word_file=word_file,
project_dir=str(project_abs_path),
tender_file=tender_file_path,
)
progress.update(task, completed=True)
else:
# 交互模式不使用Progress避免覆盖用户输入
result = agent.write_content_sync(word_file=word_file)
result = agent.write_content_sync(
word_file=word_file,
project_dir=str(project_abs_path),
tender_file=tender_file_path,
)
# 显示结果
if result.get("success"):

View File

@ -83,6 +83,7 @@ class Settings(BaseSettings):
# Word填充配置
max_sub_chapter_level: int = Field(default=3, description="子章节最大层级")
rag_search_top_k: int = Field(default=3, description="RAG检索返回结果数量")
tender_doc_search_top_k: int = Field(default=3, description="招标文档检索返回结果数量")
parent_context_length: int = Field(default=500, description="父章节上下文长度限制")
# Word格式配置

View File

@ -4,7 +4,7 @@
"""
import logging
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from ..base import BaseNode, NodeContext
from ...config.settings import get_settings
@ -115,8 +115,9 @@ class GenerateContentNode(BaseNode):
"""
return {
"emphasis": "",
"rag_enabled": False,
"rag_store": None
"rag_enabled": True,
"rag_store": "tender_doc",
"rag_sources": ["tender_doc"],
}
def _find_chapter(self, state: Dict[str, Any], chapter_id: str) -> Optional[Dict[str, Any]]:
@ -193,9 +194,8 @@ class GenerateContentNode(BaseNode):
prompt_spec = planner.build_prompt_spec(chapter)
generation_context = dict(prompt_spec)
# 如果启用RAG添加上下文信息
# 如果启用RAG按来源聚合上下文
if config.get("rag_enabled"):
# 检索相关内容
query_fragments = [chapter["title"]]
if prompt_spec.get("emphasis"):
query_fragments.append(prompt_spec["emphasis"])
@ -203,14 +203,22 @@ class GenerateContentNode(BaseNode):
query_fragments.append(prompt_spec["requirements_summary"][:120])
query = " ".join(fragment for fragment in query_fragments if fragment)
search_results = rag_tool.search(query, k=settings.rag_search_top_k)
if search_results:
relevant_context = "\n\n".join([r["content"] for r in search_results])
generation_context["rag_context"] = relevant_context
logger.info(f"RAG检索到 {len(search_results)} 条相关内容")
rag_sources = config.get("rag_sources") or self._default_rag_sources(state)
config["rag_sources"] = rag_sources
aggregated_context: List[str] = []
for source_id in rag_sources:
if source_id == "tender_doc":
aggregated_context.extend(self._search_tender_doc(state, query))
elif source_id == "global_kb":
aggregated_context.extend(self._search_global_kb(state, rag_tool, query))
else:
logger.warning("未知的RAG来源: %s", source_id)
if aggregated_context:
generation_context["rag_context"] = "\n\n".join(aggregated_context)
else:
logger.warning("RAG未检索到相关内容")
generation_context["rag_context"] = ""
else:
generation_context["rag_context"] = ""
@ -221,4 +229,69 @@ class GenerateContentNode(BaseNode):
return content
except Exception as e:
logger.error(f"内容生成失败: {e}", exc_info=True)
raise ValueError(f"章节 {chapter['id']} - {chapter['title']} 内容生成失败") from e
raise ValueError(f"章节 {chapter['id']} - {chapter['title']} 内容生成失败") from e
def _default_rag_sources(self, state: Dict[str, Any]) -> List[str]:
    """Pick the RAG sources to use when the chapter config names none.

    Available sources are returned with ``tender_doc`` ahead of
    ``global_kb``. If neither is marked available but the state carries a
    ``rag_tool``, ``["global_kb"]`` is returned; otherwise an empty list.
    """
    catalog = state.get("available_rag_sources") or {}
    ready: List[str] = [
        source_id
        for source_id in ("tender_doc", "global_kb")
        if catalog.get(source_id, {}).get("available")
    ]
    if not ready and state.get("rag_tool"):
        # No source is marked available, but a RAG tool is present.
        return ["global_kb"]
    return ready
def _search_tender_doc(self, state: Dict[str, Any], query: str) -> List[str]:
    """Search the loaded tender document and format hits as context snippets.

    Args:
        state: Workflow state; its ``tender_doc_searcher`` entry is used.
        query: Search text passed to the searcher.

    Returns:
        One ``【招标文件】<section>\\n<text>`` string per match that has text.
        Empty when no searcher is loaded, the search raises, or it returns
        nothing (including ``None``).
    """
    searcher = state.get("tender_doc_searcher")
    if not searcher:
        return []

    # NOTE(review): `settings` is not bound inside this method; presumably it
    # is a module-level object — confirm it exists at module scope.
    top_k = getattr(settings, "tender_doc_search_top_k", settings.rag_search_top_k)
    try:
        matches = searcher.search(query, top_k)
    except Exception as exc:
        logger.warning("招标文件RAG检索失败: %s", exc)
        return []

    contexts: List[str] = []
    # `or []` tolerates a searcher that returns None, matching the handling
    # in _search_global_kb.
    for match in matches or []:
        section = getattr(match, "section", "招标文件") or "招标文件"
        text = getattr(match, "text", "")
        if not text:
            continue
        contexts.append(f"【招标文件】{section}\n{text}")

    if contexts:
        logger.info("招标文件RAG检索到 %s 条内容", len(contexts))
    else:
        logger.warning("招标文件RAG未检索到相关内容")
    return contexts
def _search_global_kb(self, state: Dict[str, Any], rag_tool, query: str) -> List[str]:
    """Search the knowledge base and format hits as context snippets.

    Args:
        state: Workflow state; not read by this method.
        rag_tool: Object exposing ``search(query, k=...)``; may be falsy.
        query: Search text passed to the tool.

    Returns:
        One ``【知识库】<source>\\n<content>`` string per result that has
        content. Empty when no tool is given, the search raises, or nothing
        is found.
    """
    if not rag_tool:
        return []

    try:
        hits = rag_tool.search(query, k=settings.rag_search_top_k)
    except Exception as exc:
        logger.warning("知识库RAG检索失败: %s", exc)
        return []

    snippets: List[str] = []
    for hit in hits or []:
        body = hit.get("content")
        if body:
            # Use a generic label when the hit carries no source metadata.
            origin = (hit.get("metadata") or {}).get("source") or "知识库片段"
            snippets.append(f"【知识库】{origin}\n{body}")

    if not snippets:
        logger.warning("知识库RAG未检索到相关内容")
    else:
        logger.info("知识库RAG检索到 %s 条内容", len(snippets))
    return snippets

View File

@ -4,7 +4,7 @@
"""
import logging
from typing import Any, Dict
from typing import Any, Dict, List, Optional
from ..base import BaseNode, NodeContext
@ -80,17 +80,24 @@ class InteractWithUserNode(BaseNode):
)
use_rag = use_rag_response == ""
rag_store = None
rag_sources: list[str] = []
rag_store: Optional[str] = None
# 3. 如果使用RAG选择知识库
if use_rag:
rag_store = self._select_rag_store(interaction_handler, chapter_id, state)
rag_sources = self._select_rag_sources(interaction_handler, chapter_id, state)
if rag_sources:
rag_store = rag_sources[0]
else:
logger.warning("RAG已启用但没有可用的知识库来源将自动禁用")
use_rag = False
# 保存章节配置
config = {
"emphasis": emphasis.strip() if emphasis else "",
"rag_enabled": use_rag,
"rag_store": rag_store,
"rag_sources": rag_sources if use_rag else [],
}
state.setdefault("chapter_configs", {})[chapter_id] = config
@ -102,45 +109,96 @@ class InteractWithUserNode(BaseNode):
return state
def _select_rag_store(self, interaction_handler, chapter_id: str, state: Dict[str, Any]) -> str:
"""选择RAG知识库
def _select_rag_sources(
self,
interaction_handler,
chapter_id: str,
state: Dict[str, Any],
) -> list[str]:
available_sources = self._get_available_sources(state)
if not available_sources:
return []
Args:
interaction_handler: 交互处理器
chapter_id: 章节ID
state: 当前状态
ordered_ids = self._order_sources(available_sources)
selected: List[str] = []
Returns:
知识库标识
"""
# 从state获取RAGTool实例由InitConfigNode统一初始化
rag_tool = state.get("rag_tool")
if not rag_tool:
raise ValueError("RAGTool未初始化请检查InitConfigNode配置")
try:
stats = rag_tool.get_stats()
total_chunks = stats.get("total_chunks", 0)
total_files = stats.get("total_files", 0)
# 显示知识库统计信息
store_info = f"默认知识库 ({total_files}个文档, {total_chunks}个文档块)"
choice = interaction_handler(
if "tender_doc" in available_sources and available_sources["tender_doc"].get("available"):
# 默认启用招标文件,不再询问主来源
selected.append("tender_doc")
remaining = [sid for sid in ordered_ids if sid != "tender_doc"]
else:
# 无招标文件时,仍需选择主来源
default_source = ordered_ids[0]
options = [
(source_id, available_sources[source_id]["label"])
for source_id in ordered_ids
]
primary_choice = interaction_handler(
interaction_type="choice",
prompt="选择知识库",
options=[
("default", store_info),
("none", "不使用RAG")
],
default="default",
key=f"rag_store_{chapter_id}",
prompt="选择主要RAG来源",
options=options,
default=default_source,
key=f"rag_primary_{chapter_id}",
)
if primary_choice not in available_sources:
primary_choice = default_source
selected.append(primary_choice)
remaining = [sid for sid in ordered_ids if sid not in selected]
loop_index = 0
while remaining:
add_more = interaction_handler(
interaction_type="confirm",
prompt="是否追加其它RAG来源",
default=False,
key=f"rag_additional_confirm_{chapter_id}_{loop_index}",
)
if not add_more:
break
add_options = [
(source_id, available_sources[source_id]["label"])
for source_id in remaining
]
next_choice = interaction_handler(
interaction_type="choice",
prompt="选择需要追加的RAG来源",
options=add_options,
default=remaining[0],
key=f"rag_additional_choice_{chapter_id}_{loop_index}",
)
if choice == "default":
return "default"
return None
if next_choice in remaining:
selected.append(next_choice)
remaining.remove(next_choice)
loop_index += 1
except Exception as e:
logger.error(f"获取知识库信息失败: {e}", exc_info=True)
raise ValueError("获取知识库信息失败无法选择RAG存储") from e
return selected
def _get_available_sources(self, state: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Return the RAG sources that can actually be offered to the user.

    Entries of ``state["available_rag_sources"]`` flagged ``available`` are
    returned as-is. If none qualify but the state has a ``rag_tool``, a
    single synthetic ``global_kb`` entry is returned; otherwise an empty dict.
    """
    declared = state.get("available_rag_sources") or {}
    usable: Dict[str, Dict[str, Any]] = {}
    for source_id, meta in declared.items():
        if meta.get("available"):
            usable[source_id] = meta
    if usable:
        return usable

    if not state.get("rag_tool"):
        return {}

    # No source is flagged available, but a RAG tool is present.
    fallback = {
        "id": "global_kb",
        "label": "知识库",
        "available": True,
        "default": True,
    }
    return {"global_kb": fallback}
def _order_sources(self, sources: Dict[str, Dict[str, Any]]) -> list[str]:
    """Return source ids with default sources first, then sorted by label.

    A source lacking a ``label`` is ordered by its id instead.
    """

    def rank(source_id: str):
        meta = sources[source_id]
        # False sorts before True, so ``default`` sources come first.
        return (not meta.get("default", False), meta.get("label", source_id))

    ordered = list(sources)
    ordered.sort(key=rank)
    return ordered

View File

@ -0,0 +1,173 @@
"""Tests for multi-source RAG selection and aggregation."""
from __future__ import annotations
from typing import Any, Dict
from bidmaster.nodes.base import NodeContext
from bidmaster.nodes.content.generate_content import GenerateContentNode
from bidmaster.nodes.content.interact_user import InteractWithUserNode
from bidmaster.utils.document_context import DocumentContextMatch
class StubInteractionHandler:
    """Callable test double that answers prompts from a canned mapping."""

    def __init__(self, responses: Dict[str, Any]):
        """Store the ``key -> answer`` mapping used to answer prompts."""
        self.responses = responses

    def __call__(self, interaction_type: str, **kwargs):
        """Return the canned answer for ``key``, else the prompt's ``default``."""
        fallback = kwargs.get("default")
        lookup = kwargs.get("key")
        if not lookup:
            return fallback
        return self.responses.get(lookup, fallback)
def test_interact_user_allows_multiple_rag_sources():
    """Tender doc is selected first and the knowledge base can be appended."""
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}
    handler = StubInteractionHandler(
        {
            "emphasis_chapter_1": "",
            "use_rag_chapter_1": "",
            # Confirm the first "add another source" prompt and pick the KB.
            "rag_additional_confirm_chapter_1_0": True,
            "rag_additional_choice_chapter_1_0": "global_kb",
        }
    )
    available_sources = {
        "tender_doc": {
            "id": "tender_doc",
            "label": "招标文件",
            "available": True,
            "default": True,
        },
        "global_kb": {
            "id": "global_kb",
            "label": "知识库",
            "available": True,
            "default": False,
        },
    }
    state: Dict[str, Any] = dict(
        needs_interaction=True,
        current_chapter=chapter,
        chapter_configs={},
        interaction_handler=handler,
        available_rag_sources=available_sources,
    )

    InteractWithUserNode().execute(state, NodeContext())

    chapter_config = state["chapter_configs"][chapter["id"]]
    assert chapter_config["rag_enabled"] is True
    assert chapter_config["rag_sources"] == ["tender_doc", "global_kb"]
def test_interact_user_disables_rag_without_sources():
    """RAG is switched off when no source is available to choose from."""
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}
    handler = StubInteractionHandler(
        {
            "emphasis_chapter_1": "",
            "use_rag_chapter_1": "",
        }
    )
    state: Dict[str, Any] = dict(
        needs_interaction=True,
        current_chapter=chapter,
        chapter_configs={},
        interaction_handler=handler,
        # Empty catalogue and no ``rag_tool`` entry in the state.
        available_rag_sources={},
    )

    InteractWithUserNode().execute(state, NodeContext())

    chapter_config = state["chapter_configs"][chapter["id"]]
    assert chapter_config["rag_enabled"] is False
    assert chapter_config["rag_sources"] == []
def test_interact_user_prompts_for_primary_when_no_tender_doc():
    """Without a tender doc, the answer to the primary-source prompt is used."""
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}
    handler = StubInteractionHandler(
        {
            "emphasis_chapter_1": "",
            "use_rag_chapter_1": "",
            "rag_primary_chapter_1": "global_kb",
        }
    )
    kb_only = {
        "global_kb": {
            "id": "global_kb",
            "label": "知识库",
            "available": True,
            "default": True,
        }
    }
    state: Dict[str, Any] = dict(
        needs_interaction=True,
        current_chapter=chapter,
        chapter_configs={},
        interaction_handler=handler,
        available_rag_sources=kb_only,
    )

    InteractWithUserNode().execute(state, NodeContext())

    chapter_config = state["chapter_configs"][chapter["id"]]
    assert chapter_config["rag_enabled"] is True
    assert chapter_config["rag_sources"] == ["global_kb"]
class DummyTenderSearcher:
    """Tender searcher stub that always yields one fixed match."""

    def search(self, query: str, top_k: int):  # pragma: no cover - simple stub
        """Return a single canned match, ignoring ``query`` and ``top_k``."""
        canned_match = DocumentContextMatch(
            text="Tender snippet",
            section="章节一",
            score=0.9,
            source_type="paragraph",
            metadata={},
        )
        return [canned_match]
class DummyRagTool:
    """RAG tool stub: fixed search hit, and records the generation context."""

    def __init__(self):
        # Last ``context`` passed to generate_content; None until first call.
        self.last_context: Dict[str, Any] | None = None

    def search(self, query: str, k: int):  # pragma: no cover - simple stub
        """Return one canned knowledge-base hit, ignoring the arguments."""
        canned_hit = dict(content="KB snippet", metadata=dict(source="kb.docx"))
        return [canned_hit]

    def generate_content(self, task_id: str, context: Dict[str, Any]):
        """Remember ``context`` for later assertions and return a fixed text."""
        self.last_context = context
        return "generated"
def test_generate_content_merges_multiple_sources():
    """Snippets from the tender doc and the knowledge base are both used."""
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}
    rag_tool = DummyRagTool()
    chapter_configs = {
        "chapter_1": {
            "rag_enabled": True,
            "rag_sources": ["tender_doc", "global_kb"],
        }
    }
    state: Dict[str, Any] = dict(
        current_chapter=chapter,
        chapter_queue=[chapter],
        chapter_configs=chapter_configs,
        chapter_children_map={},
        generated_contents={},
        rag_tool=rag_tool,
        tender_doc_searcher=DummyTenderSearcher(),
        available_rag_sources={
            "tender_doc": {"available": True},
            "global_kb": {"available": True},
        },
    )

    GenerateContentNode().execute(state, NodeContext())

    assert state["generated_contents"]["chapter_1"] == "generated"
    # The stub records the context it was asked to generate from.
    assert rag_tool.last_context is not None
    merged_context = rag_tool.last_context.get("rag_context")
    assert "【招标文件】" in merged_context
    assert "【知识库】" in merged_context