diff --git a/src/bidmaster/agents/content_writer.py b/src/bidmaster/agents/content_writer.py index 1693d38..ef1e854 100644 --- a/src/bidmaster/agents/content_writer.py +++ b/src/bidmaster/agents/content_writer.py @@ -3,11 +3,13 @@ 负责标书章节内容的自动填写(协调器模式)。 """ +import json import logging from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Optional from .single_chapter_agent import SingleChapterAgent +from ..utils.document_context import DocumentContextBuilder, DocumentContextSearcher logger = logging.getLogger(__name__) @@ -29,7 +31,12 @@ class ContentWriterAgent: """ self.interaction_handler = interaction_handler - def write_content_sync(self, word_file: str) -> Dict[str, Any]: + def write_content_sync( + self, + word_file: str, + project_dir: Optional[str] = None, + tender_file: Optional[str] = None, + ) -> Dict[str, Any]: """同步执行内容填写 Args: @@ -54,11 +61,14 @@ class ContentWriterAgent: init_state = { "word_file": word_file, + "project_dir": str(project_dir) if project_dir else None, } state = init_node.execute(init_state, context) chapter_queue = state["chapter_queue"] + state = self._prepare_rag_sources(state, project_dir, tender_file) + logger.info(f"共 {len(chapter_queue)} 个章节待处理") for i, ch in enumerate(chapter_queue): logger.info(f" [{i+1}] {ch['id']} - {ch['title']} (level={ch['level']})") @@ -127,4 +137,170 @@ class ContentWriterAgent: """创建交互模式Agent""" from .interaction import InteractionHandler, InteractionMode handler = InteractionHandler(mode=InteractionMode.INTERACTIVE) - return cls(handler) \ No newline at end of file + return cls(handler) + + def _prepare_rag_sources( + self, + state: Dict[str, Any], + project_dir: Optional[str], + tender_file: Optional[str], + ) -> Dict[str, Any]: + """构建可用的RAG来源列表,并按需加载招标文件上下文。""" + + tender_path = self._resolve_tender_source(project_dir, tender_file) + tender_info: Optional[Dict[str, Any]] = None + + if tender_path: + tender_info = self._build_tender_searcher(state, tender_path) + + if not tender_info: + prompted_path = self._prompt_for_tender_source() + if prompted_path: + tender_info = self._build_tender_searcher(state, prompted_path) + + if tender_info: + state["tender_doc_source_file"] = tender_info["file_path"] + state["tender_doc_searcher"] = tender_info["searcher"] + else: + state.pop("tender_doc_source_file", None) + state.pop("tender_doc_searcher", None) + + state["available_rag_sources"] = self._compose_available_sources(state, tender_info) + return state + + def _resolve_tender_source( + self, + project_dir: Optional[str], + tender_file: Optional[str], + ) -> Optional[str]: + if tender_file: + path = Path(tender_file) + return str(path) if path.exists() else None + + if not project_dir: + return None + + analysis_file = Path(project_dir) / "analysis_result.json" + if not analysis_file.exists(): + return None + + try: + with open(analysis_file, "r", encoding="utf-8") as fp: + analysis_data = json.load(fp) + except Exception: + logger.warning("读取 analysis_result.json 失败,无法自动注入招标上下文") + return None + + source_file = analysis_data.get("source_file") + if not source_file: + return None + + source_path = Path(source_file) + if not source_path.exists(): + logger.warning("analysis_result.json 中的招标文件不存在: %s", source_file) + return None + + return str(source_path) + + def _prompt_for_tender_source(self) -> Optional[str]: + if not self.interaction_handler: + return None + + try: + should_provide = self.interaction_handler( + interaction_type="confirm", + prompt="未找到招标文件上下文,是否手动选择一个招标文件供RAG使用?", + default=False, + key="provide_tender_doc", + ) + except Exception: + return None + + if not should_provide: + return None + + path = self.interaction_handler( + interaction_type="file_path", + prompt="请输入招标文件(.docx)路径", + default="", + key="tender_doc_path", + validation={"exists": True, "extensions": [".docx"]}, + ) + if not path: + return None + file_path = Path(path) + return str(file_path) if file_path.exists() else None + + def _build_tender_searcher( + self, + state: Dict[str, Any], + tender_path: str, + ) -> Optional[Dict[str, Any]]: + try: + builder = DocumentContextBuilder() + document_context = builder.build(tender_path) + if document_context.is_empty(): + logger.warning("招标文件未生成有效上下文: %s", tender_path) + return None + + searcher = DocumentContextSearcher(document_context) + logger.info( + "已载入招标文件上下文: %s (%s段)", + Path(tender_path).name, + len(document_context.chunks), + ) + return { + "file_path": tender_path, + "chunk_count": len(document_context.chunks), + "searcher": searcher, + } + except Exception as exc: + logger.warning("构建招标文件上下文失败: %s", exc) + return None + + def _compose_available_sources( + self, + state: Dict[str, Any], + tender_info: Optional[Dict[str, Any]], + ) -> Dict[str, Dict[str, Any]]: + sources: Dict[str, Dict[str, Any]] = {} + + if tender_info: + file_name = Path(tender_info["file_path"]).name + chunk_count = tender_info.get("chunk_count", 0) + sources["tender_doc"] = { + "id": "tender_doc", + "label": f"招标文件({file_name},{chunk_count}段)", + "available": True, + "default": True, + } + else: + sources["tender_doc"] = { + "id": "tender_doc", + "label": "招标文件(未配置)", + "available": False, + "default": False, + } + + rag_tool = state.get("rag_tool") + kb_label = "知识库(未初始化)" + kb_available = False + kb_files = 0 + if rag_tool: + try: + stats = rag_tool.get_stats() + kb_files = stats.get("total_files", 0) + chunks = stats.get("total_chunks", 0) + kb_label = f"知识库({kb_files}个文档,{chunks}段)" + kb_available = True + except Exception: + kb_label = "知识库(统计失败)" + + sources["global_kb"] = { + "id": "global_kb", + "label": kb_label, + "available": kb_available, + "default": not tender_info, + } + + return sources \ No newline at end of file diff --git a/src/bidmaster/cli/write.py b/src/bidmaster/cli/write.py index e5c7449..d0cfa24 100644 --- a/src/bidmaster/cli/write.py +++ b/src/bidmaster/cli/write.py @@ -36,7 +36,13 @@ def write(): default="interactive", help="交互模式:interactive(交互)|silent(静默)", ) -def start(project_dir: str, mode: str): +@click.option( + "--tender-file", + "-t", + type=click.Path(exists=True, dir_okay=False), + help="用于RAG的招标文件(.docx)", +) +def start(project_dir: str, mode: str, tender_file: str | None = None): """开始填写标书内容 PROJECT_DIR: 项目目录路径(默认当前目录) @@ -58,7 +64,8 @@ def start(project_dir: str, mode: str): console.print(f"❌ 项目目录不存在: {project_dir}", style="red") return - console.print(f"📁 项目目录: {project_path.absolute()}", style="dim") + project_abs_path = project_path.resolve() + console.print(f"📁 项目目录: {project_abs_path}", style="dim") # 自动查找Word文件 word_file = _find_word_file(project_path) @@ -68,6 +75,9 @@ def start(project_dir: str, mode: str): console.print(f"📄 找到Word文档: {Path(word_file).name}", style="green") console.print("📋 将直接从Word文档提取章节结构", style="dim") + # 规范化可选招标文件路径 + tender_file_path = str(Path(tender_file).resolve()) if tender_file else None + # 创建Agent if mode == "silent": agent = ContentWriterAgent.create_silent() @@ -91,11 +101,19 @@ def start(project_dir: str, mode: str): console=console, ) as progress: task = progress.add_task("正在处理章节...", total=None) - result = agent.write_content_sync(word_file=word_file) + result = agent.write_content_sync( + word_file=word_file, + project_dir=str(project_abs_path), + tender_file=tender_file_path, + ) progress.update(task, completed=True) else: # 交互模式:不使用Progress,避免覆盖用户输入 - result = agent.write_content_sync(word_file=word_file) + result = agent.write_content_sync( + word_file=word_file, + project_dir=str(project_abs_path), + tender_file=tender_file_path, + ) # 显示结果 if result.get("success"): diff --git a/src/bidmaster/config/settings.py b/src/bidmaster/config/settings.py index 79b1b26..da6462f 100644 --- a/src/bidmaster/config/settings.py +++ b/src/bidmaster/config/settings.py @@ -83,6 +83,7 @@ class Settings(BaseSettings): # Word填充配置 max_sub_chapter_level: int = Field(default=3, description="子章节最大层级") rag_search_top_k: int = Field(default=3, description="RAG检索返回结果数量") + tender_doc_search_top_k: int = Field(default=3, description="招标文档检索返回结果数量") parent_context_length: int = Field(default=500, description="父章节上下文长度限制") # Word格式配置 diff --git a/src/bidmaster/nodes/content/generate_content.py b/src/bidmaster/nodes/content/generate_content.py index 8b9a042..7993cb3 100644 --- a/src/bidmaster/nodes/content/generate_content.py +++ b/src/bidmaster/nodes/content/generate_content.py @@ -4,7 +4,7 @@ """ import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from ..base import BaseNode, NodeContext from ...config.settings import get_settings @@ -115,8 +115,9 @@ class GenerateContentNode(BaseNode): """ return { "emphasis": "", - "rag_enabled": False, - "rag_store": None + "rag_enabled": True, + "rag_store": "tender_doc", + "rag_sources": ["tender_doc"], } def _find_chapter(self, state: Dict[str, Any], chapter_id: str) -> Optional[Dict[str, Any]]: @@ -193,9 +194,8 @@ class GenerateContentNode(BaseNode): prompt_spec = planner.build_prompt_spec(chapter) generation_context = dict(prompt_spec) - # 如果启用RAG,添加上下文信息 + # 如果启用RAG,按来源聚合上下文 if config.get("rag_enabled"): - # 检索相关内容 query_fragments = [chapter["title"]] if prompt_spec.get("emphasis"): query_fragments.append(prompt_spec["emphasis"]) @@ -203,14 +203,22 @@ class GenerateContentNode(BaseNode): query_fragments.append(prompt_spec["requirements_summary"][:120]) query = " ".join(fragment for fragment in query_fragments if fragment) - search_results = rag_tool.search(query, k=settings.rag_search_top_k) - if search_results: - relevant_context = "\n\n".join([r["content"] for r in search_results]) - generation_context["rag_context"] = relevant_context - logger.info(f"RAG检索到 {len(search_results)} 条相关内容") + rag_sources = config.get("rag_sources") or self._default_rag_sources(state) + config["rag_sources"] = rag_sources + + aggregated_context: List[str] = [] + for source_id in rag_sources: + if source_id == "tender_doc": + aggregated_context.extend(self._search_tender_doc(state, query)) + elif source_id == "global_kb": + aggregated_context.extend(self._search_global_kb(state, rag_tool, query)) + else: + logger.warning("未知的RAG来源: %s", source_id) + + if aggregated_context: + generation_context["rag_context"] = "\n\n".join(aggregated_context) else: - logger.warning("RAG未检索到相关内容") generation_context["rag_context"] = "" else: generation_context["rag_context"] = "" @@ -221,4 +229,69 @@ class GenerateContentNode(BaseNode): return content except Exception as e: logger.error(f"内容生成失败: {e}", exc_info=True) - raise ValueError(f"章节 {chapter['id']} - {chapter['title']} 内容生成失败") from e \ No newline at end of file + raise ValueError(f"章节 {chapter['id']} - {chapter['title']} 内容生成失败") from e + + def _default_rag_sources(self, state: Dict[str, Any]) -> List[str]: + sources = state.get("available_rag_sources") or {} + preferred: List[str] = [] + if sources.get("tender_doc", {}).get("available"): + preferred.append("tender_doc") + if sources.get("global_kb", {}).get("available"): + preferred.append("global_kb") + if preferred: + return preferred + if state.get("rag_tool"): + return ["global_kb"] + return [] + + def _search_tender_doc(self, state: Dict[str, Any], query: str) -> List[str]: + searcher = state.get("tender_doc_searcher") + if not searcher: + return [] + + top_k = getattr(settings, "tender_doc_search_top_k", settings.rag_search_top_k) + try: + matches = searcher.search(query, top_k) + except Exception as exc: + logger.warning("招标文件RAG检索失败: %s", exc) + return [] + + contexts: List[str] = [] + for match in matches: + section = getattr(match, "section", "招标文件") or "招标文件" + text = getattr(match, "text", "") + if not text: + continue + contexts.append(f"【招标文件】{section}\n{text}") + + if contexts: + logger.info("招标文件RAG检索到 %s 条内容", len(contexts)) + else: + logger.warning("招标文件RAG未检索到相关内容") + + return contexts + + def _search_global_kb(self, state: Dict[str, Any], rag_tool, query: str) -> List[str]: + if not rag_tool: + return [] + + try: + results = rag_tool.search(query, k=settings.rag_search_top_k) + except Exception as exc: + logger.warning("知识库RAG检索失败: %s", exc) + return [] + + contexts: List[str] = [] + for result in results or []: + content = result.get("content") + if not content: + continue + source = (result.get("metadata") or {}).get("source") or "知识库片段" + contexts.append(f"【知识库】{source}\n{content}") + + if contexts: + logger.info("知识库RAG检索到 %s 条内容", len(contexts)) + else: + logger.warning("知识库RAG未检索到相关内容") + + return contexts \ No newline at end of file diff --git a/src/bidmaster/nodes/content/interact_user.py b/src/bidmaster/nodes/content/interact_user.py index 9d452df..8408cee 100644 --- a/src/bidmaster/nodes/content/interact_user.py +++ b/src/bidmaster/nodes/content/interact_user.py @@ -4,7 +4,7 @@ """ import logging -from typing import Any, Dict +from typing import Any, Dict, List, Optional from ..base import BaseNode, NodeContext @@ -80,17 +80,24 @@ class InteractWithUserNode(BaseNode): ) use_rag = use_rag_response == "是" - rag_store = None + rag_sources: list[str] = [] + rag_store: Optional[str] = None # 3. 如果使用RAG,选择知识库 if use_rag: - rag_store = self._select_rag_store(interaction_handler, chapter_id, state) + rag_sources = self._select_rag_sources(interaction_handler, chapter_id, state) + if rag_sources: + rag_store = rag_sources[0] + else: + logger.warning("RAG已启用但没有可用的知识库来源,将自动禁用") + use_rag = False # 保存章节配置 config = { "emphasis": emphasis.strip() if emphasis else "", "rag_enabled": use_rag, "rag_store": rag_store, + "rag_sources": rag_sources if use_rag else [], } state.setdefault("chapter_configs", {})[chapter_id] = config @@ -102,45 +109,96 @@ class InteractWithUserNode(BaseNode): return state - def _select_rag_store(self, interaction_handler, chapter_id: str, state: Dict[str, Any]) -> str: - """选择RAG知识库 + def _select_rag_sources( + self, + interaction_handler, + chapter_id: str, + state: Dict[str, Any], + ) -> list[str]: + available_sources = self._get_available_sources(state) + if not available_sources: + return [] - Args: - interaction_handler: 交互处理器 - chapter_id: 章节ID - state: 当前状态 + ordered_ids = self._order_sources(available_sources) + selected: List[str] = [] - Returns: - 知识库标识 - """ - # 从state获取RAGTool实例(由InitConfigNode统一初始化) - rag_tool = state.get("rag_tool") - if not rag_tool: - raise ValueError("RAGTool未初始化,请检查InitConfigNode配置") - - try: - stats = rag_tool.get_stats() - total_chunks = stats.get("total_chunks", 0) - total_files = stats.get("total_files", 0) - - # 显示知识库统计信息 - store_info = f"默认知识库 ({total_files}个文档, {total_chunks}个文档块)" - - choice = interaction_handler( + if "tender_doc" in available_sources and available_sources["tender_doc"].get("available"): + # 默认启用招标文件,不再询问主来源 + selected.append("tender_doc") + remaining = [sid for sid in ordered_ids if sid != "tender_doc"] + else: + # 无招标文件时,仍需选择主来源 + default_source = ordered_ids[0] + options = [ + (source_id, available_sources[source_id]["label"]) + for source_id in ordered_ids + ] + primary_choice = interaction_handler( interaction_type="choice", - prompt="选择知识库", - options=[ - ("default", store_info), - ("none", "不使用RAG") - ], - default="default", - key=f"rag_store_{chapter_id}", + prompt="选择主要RAG来源", + options=options, + default=default_source, + key=f"rag_primary_{chapter_id}", + ) + if primary_choice not in available_sources: + primary_choice = default_source + selected.append(primary_choice) + remaining = [sid for sid in ordered_ids if sid not in selected] + + loop_index = 0 + while remaining: + add_more = interaction_handler( + interaction_type="confirm", + prompt="是否追加其它RAG来源?", + default=False, + key=f"rag_additional_confirm_{chapter_id}_{loop_index}", + ) + if not add_more: + break + + add_options = [ + (source_id, available_sources[source_id]["label"]) + for source_id in remaining + ] + next_choice = interaction_handler( + interaction_type="choice", + prompt="选择需要追加的RAG来源", + options=add_options, + default=remaining[0], + key=f"rag_additional_choice_{chapter_id}_{loop_index}", ) - if choice == "default": - return "default" - return None + if next_choice in remaining: + selected.append(next_choice) + remaining.remove(next_choice) + loop_index += 1 - except Exception as e: - logger.error(f"获取知识库信息失败: {e}", exc_info=True) - raise ValueError("获取知识库信息失败,无法选择RAG存储") from e \ No newline at end of file + return selected + + def _get_available_sources(self, state: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: + sources = state.get("available_rag_sources") or {} + usable = {sid: meta for sid, meta in sources.items() if meta.get("available")} + + if usable: + return usable + + if state.get("rag_tool"): + return { + "global_kb": { + "id": "global_kb", + "label": "知识库", + "available": True, + "default": True, + } + } + + return {} + + def _order_sources(self, sources: Dict[str, Dict[str, Any]]) -> list[str]: + return sorted( + sources.keys(), + key=lambda sid: ( + not sources[sid].get("default", False), + sources[sid].get("label", sid), + ), + ) \ No newline at end of file diff --git a/tests/unit/test_rag_selection.py b/tests/unit/test_rag_selection.py new file mode 100644 index 0000000..9ffdcba --- /dev/null +++ b/tests/unit/test_rag_selection.py @@ -0,0 +1,173 @@ +"""Tests for multi-source RAG selection and aggregation.""" + +from __future__ import annotations + +from typing import Any, Dict + +from bidmaster.nodes.base import NodeContext +from bidmaster.nodes.content.generate_content import GenerateContentNode +from bidmaster.nodes.content.interact_user import InteractWithUserNode +from bidmaster.utils.document_context import DocumentContextMatch + + +class StubInteractionHandler: + def __init__(self, responses: Dict[str, Any]): + self.responses = responses + + def __call__(self, interaction_type: str, **kwargs): + key = kwargs.get("key") + if key and key in self.responses: + return self.responses[key] + return kwargs.get("default") + + +def test_interact_user_allows_multiple_rag_sources(): + chapter = {"id": "chapter_1", "title": "概述", "level": 1} + responses = { + "emphasis_chapter_1": "", + "use_rag_chapter_1": "是", + "rag_additional_confirm_chapter_1_0": True, + "rag_additional_choice_chapter_1_0": "global_kb", + } + handler = StubInteractionHandler(responses) + + state: Dict[str, Any] = { + "needs_interaction": True, + "current_chapter": chapter, + "chapter_configs": {}, + "interaction_handler": handler, + "available_rag_sources": { + "tender_doc": { + "id": "tender_doc", + "label": "招标文件", + "available": True, + "default": True, + }, + "global_kb": { + "id": "global_kb", + "label": "知识库", + "available": True, + "default": False, + }, + }, + } + + node = InteractWithUserNode() + node.execute(state, NodeContext()) + + config = state["chapter_configs"][chapter["id"]] + assert config["rag_enabled"] is True + assert config["rag_sources"] == ["tender_doc", "global_kb"] + + +def test_interact_user_disables_rag_without_sources(): + chapter = {"id": "chapter_1", "title": "概述", "level": 1} + responses = { + "emphasis_chapter_1": "", + "use_rag_chapter_1": "是", + } + handler = StubInteractionHandler(responses) + + state: Dict[str, Any] = { + "needs_interaction": True, + "current_chapter": chapter, + "chapter_configs": {}, + "interaction_handler": handler, + "available_rag_sources": {}, + } + + node = InteractWithUserNode() + node.execute(state, NodeContext()) + + config = state["chapter_configs"][chapter["id"]] + assert config["rag_enabled"] is False + assert config["rag_sources"] == [] + + +def test_interact_user_prompts_for_primary_when_no_tender_doc(): + chapter = {"id": "chapter_1", "title": "概述", "level": 1} + responses = { + "emphasis_chapter_1": "", + "use_rag_chapter_1": "是", + "rag_primary_chapter_1": "global_kb", + } + handler = StubInteractionHandler(responses) + + state: Dict[str, Any] = { + "needs_interaction": True, + "current_chapter": chapter, + "chapter_configs": {}, + "interaction_handler": handler, + "available_rag_sources": { + "global_kb": { + "id": "global_kb", + "label": "知识库", + "available": True, + "default": True, + } + }, + } + + node = InteractWithUserNode() + node.execute(state, NodeContext()) + + config = state["chapter_configs"][chapter["id"]] + assert config["rag_enabled"] is True + assert config["rag_sources"] == ["global_kb"] + + +class DummyTenderSearcher: + def search(self, query: str, top_k: int): # pragma: no cover - simple stub + return [ + DocumentContextMatch( + text="Tender snippet", + section="章节一", + score=0.9, + source_type="paragraph", + metadata={}, + ) + ] + + +class DummyRagTool: + def __init__(self): + self.last_context: Dict[str, Any] | None = None + + def search(self, query: str, k: int): # pragma: no cover - simple stub + return [{"content": "KB snippet", "metadata": {"source": "kb.docx"}}] + + def generate_content(self, task_id: str, context: Dict[str, Any]): + self.last_context = context + return "generated" + + +def test_generate_content_merges_multiple_sources(): + chapter = {"id": "chapter_1", "title": "概述", "level": 1} + rag_tool = DummyRagTool() + state: Dict[str, Any] = { + "current_chapter": chapter, + "chapter_queue": [chapter], + "chapter_configs": { + "chapter_1": { + "rag_enabled": True, + "rag_sources": ["tender_doc", "global_kb"], + } + }, + "chapter_children_map": {}, + "generated_contents": {}, + "rag_tool": rag_tool, + "tender_doc_searcher": DummyTenderSearcher(), + "available_rag_sources": { + "tender_doc": {"available": True}, + "global_kb": {"available": True}, + }, + } + + node = GenerateContentNode() + node.execute(state, NodeContext()) + + assert state["generated_contents"]["chapter_1"] == "generated" + assert rag_tool.last_context is not None + rag_context = rag_tool.last_context.get("rag_context") + assert "【招标文件】" in rag_context + assert "【知识库】" in rag_context