"""Tests for multi-source RAG selection and aggregation."""

from __future__ import annotations

from typing import Any, Dict

from bidmaster.nodes.base import NodeContext
from bidmaster.nodes.content.generate_content import GenerateContentNode
from bidmaster.nodes.content.interact_user import InteractWithUserNode
from bidmaster.utils.document_context import DocumentContextMatch


class StubInteractionHandler:
    """Scripted stand-in for the interactive prompt callback.

    Each call is answered by looking up the ``key`` keyword argument in
    ``responses``. Calls without a key, or whose key has no scripted
    answer, receive the ``default`` keyword argument supplied by the caller
    (``None`` if none was given).
    """

    def __init__(self, responses: Dict[str, Any]):
        self.responses = responses

    def __call__(self, interaction_type: str, **kwargs: Any) -> Any:
        # interaction_type is accepted to match the handler call signature
        # but plays no part in choosing the answer.
        key = kwargs.get("key")
        if key and key in self.responses:
            return self.responses[key]
        return kwargs.get("default")


def _run_interaction(
    chapter: Dict[str, Any],
    responses: Dict[str, Any],
    available_rag_sources: Dict[str, Any],
) -> Dict[str, Any]:
    """Execute InteractWithUserNode with scripted answers.

    Builds a fresh state for ``chapter``, runs the node once, and returns
    the config the node stored under ``chapter_configs[chapter["id"]]``.
    """
    state: Dict[str, Any] = {
        "needs_interaction": True,
        "current_chapter": chapter,
        "chapter_configs": {},
        "interaction_handler": StubInteractionHandler(responses),
        "available_rag_sources": available_rag_sources,
    }
    node = InteractWithUserNode()
    node.execute(state, NodeContext())
    return state["chapter_configs"][chapter["id"]]


def test_interact_user_allows_multiple_rag_sources():
    """Accepting an additional source keeps the default source first."""
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}
    responses = {
        "emphasis_chapter_1": "",
        "use_rag_chapter_1": "是",
        "rag_additional_confirm_chapter_1_0": True,
        "rag_additional_choice_chapter_1_0": "global_kb",
    }
    available_rag_sources = {
        "tender_doc": {
            "id": "tender_doc",
            "label": "招标文件",
            "available": True,
            "default": True,
        },
        "global_kb": {
            "id": "global_kb",
            "label": "知识库",
            "available": True,
            "default": False,
        },
    }

    config = _run_interaction(chapter, responses, available_rag_sources)

    assert config["rag_enabled"] is True
    assert config["rag_sources"] == ["tender_doc", "global_kb"]


def test_interact_user_disables_rag_without_sources():
    """Opting into RAG with no available sources leaves RAG disabled."""
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}
    responses = {
        "emphasis_chapter_1": "",
        "use_rag_chapter_1": "是",
    }

    config = _run_interaction(chapter, responses, {})

    assert config["rag_enabled"] is False
    assert config["rag_sources"] == []


def test_interact_user_prompts_for_primary_when_no_tender_doc():
    """Without a tender document, the primary source comes from a prompt."""
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}
    responses = {
        "emphasis_chapter_1": "",
        "use_rag_chapter_1": "是",
        "rag_primary_chapter_1": "global_kb",
    }
    available_rag_sources = {
        "global_kb": {
            "id": "global_kb",
            "label": "知识库",
            "available": True,
            "default": True,
        }
    }

    config = _run_interaction(chapter, responses, available_rag_sources)

    assert config["rag_enabled"] is True
    assert config["rag_sources"] == ["global_kb"]


class DummyTenderSearcher:
    """Tender-document searcher stub that always returns one fixed match."""

    def search(self, query: str, top_k: int):  # pragma: no cover - simple stub
        return [
            DocumentContextMatch(
                text="Tender snippet",
                section="章节一",
                score=0.9,
                source_type="paragraph",
                metadata={},
            )
        ]


class DummyRagTool:
    """RAG tool stub: fixed KB search hits; records the generation context."""

    def __init__(self):
        # Context passed to the most recent generate_content call, if any.
        self.last_context: Dict[str, Any] | None = None

    def search(self, query: str, k: int):  # pragma: no cover - simple stub
        return [{"content": "KB snippet", "metadata": {"source": "kb.docx"}}]

    def generate_content(self, task_id: str, context: Dict[str, Any]):
        self.last_context = context
        return "generated"


def test_generate_content_merges_multiple_sources():
    """Both configured sources appear, labelled, in the generation context."""
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}
    rag_tool = DummyRagTool()
    state: Dict[str, Any] = {
        "current_chapter": chapter,
        "chapter_queue": [chapter],
        "chapter_configs": {
            "chapter_1": {
                "rag_enabled": True,
                "rag_sources": ["tender_doc", "global_kb"],
            }
        },
        "chapter_children_map": {},
        "generated_contents": {},
        "rag_tool": rag_tool,
        "tender_doc_searcher": DummyTenderSearcher(),
        "available_rag_sources": {
            "tender_doc": {"available": True},
            "global_kb": {"available": True},
        },
    }

    node = GenerateContentNode()
    node.execute(state, NodeContext())

    assert state["generated_contents"]["chapter_1"] == "generated"
    assert rag_tool.last_context is not None
    rag_context = rag_tool.last_context.get("rag_context")
    # Fail with a clear assertion rather than a TypeError from `in None`.
    assert rag_context is not None
    assert "【招标文件】" in rag_context
    assert "【知识库】" in rag_context