bidmaster-cli/tests/unit/test_rag_selection.py
sladro 08db5a66bb feat: default tender doc rag selection
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2025-11-19 16:24:56 +08:00

174 lines
5.2 KiB
Python

"""Tests for multi-source RAG selection and aggregation."""
from __future__ import annotations
from typing import Any, Dict
from bidmaster.nodes.base import NodeContext
from bidmaster.nodes.content.generate_content import GenerateContentNode
from bidmaster.nodes.content.interact_user import InteractWithUserNode
from bidmaster.utils.document_context import DocumentContextMatch
class StubInteractionHandler:
    """Callable test double that answers interaction prompts from a dict.

    The handler is called with an interaction type plus keyword arguments.
    When the ``key`` keyword is truthy and present in ``responses``, the
    mapped value is returned; otherwise the ``default`` keyword (or ``None``
    when it is absent) is returned.
    """

    def __init__(self, responses: Dict[str, Any]):
        # Scripted answers, looked up by the ``key`` keyword of each call.
        self.responses = responses

    def __call__(self, interaction_type: str, **kwargs):
        """Return the scripted answer for ``kwargs["key"]`` or the default.

        ``interaction_type`` is accepted for signature compatibility but is
        not consulted when choosing the answer.
        """
        prompt_key = kwargs.get("key")
        # Missing/empty key or an unscripted key: fall back to the default.
        if not prompt_key or prompt_key not in self.responses:
            return kwargs.get("default")
        return self.responses[prompt_key]
def test_interact_user_allows_multiple_rag_sources():
    """Expect both sources in the chapter config when an extra one is chosen.

    Both ``tender_doc`` and ``global_kb`` are offered as available, and the
    scripted handler answers the ``rag_additional_*`` prompts with a
    confirmation and the ``global_kb`` choice.
    """
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}

    # Scripted answers keyed by prompt key; anything else falls back to the
    # default passed by the node.
    handler = StubInteractionHandler(
        {
            "emphasis_chapter_1": "",
            "use_rag_chapter_1": "",
            "rag_additional_confirm_chapter_1_0": True,
            "rag_additional_choice_chapter_1_0": "global_kb",
        }
    )

    tender_doc_source = {
        "id": "tender_doc",
        "label": "招标文件",
        "available": True,
        "default": True,
    }
    global_kb_source = {
        "id": "global_kb",
        "label": "知识库",
        "available": True,
        "default": False,
    }
    state: Dict[str, Any] = {
        "needs_interaction": True,
        "current_chapter": chapter,
        "chapter_configs": {},
        "interaction_handler": handler,
        "available_rag_sources": {
            "tender_doc": tender_doc_source,
            "global_kb": global_kb_source,
        },
    }

    InteractWithUserNode().execute(state, NodeContext())

    # The node is expected to have written a config entry for this chapter.
    chapter_config = state["chapter_configs"][chapter["id"]]
    assert chapter_config["rag_enabled"] is True
    assert chapter_config["rag_sources"] == ["tender_doc", "global_kb"]
def test_interact_user_disables_rag_without_sources():
    """Expect RAG to be off with no sources when none are available.

    ``available_rag_sources`` is an empty dict, so the resulting chapter
    config must report ``rag_enabled`` as ``False`` and an empty source list.
    """
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}

    # Only the emphasis and use-RAG prompts are scripted.
    handler = StubInteractionHandler(
        {
            "emphasis_chapter_1": "",
            "use_rag_chapter_1": "",
        }
    )

    state: Dict[str, Any] = {
        "needs_interaction": True,
        "current_chapter": chapter,
        "chapter_configs": {},
        "interaction_handler": handler,
        "available_rag_sources": {},
    }

    InteractWithUserNode().execute(state, NodeContext())

    chapter_config = state["chapter_configs"][chapter["id"]]
    assert chapter_config["rag_enabled"] is False
    assert chapter_config["rag_sources"] == []
def test_interact_user_prompts_for_primary_when_no_tender_doc():
    """Expect ``global_kb`` alone when it is the only source on offer.

    No ``tender_doc`` entry is available, and the scripted handler answers
    the ``rag_primary_chapter_1`` prompt with ``global_kb``.
    """
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}

    handler = StubInteractionHandler(
        {
            "emphasis_chapter_1": "",
            "use_rag_chapter_1": "",
            "rag_primary_chapter_1": "global_kb",
        }
    )

    # The knowledge base is the single available source in this scenario.
    global_kb_source = {
        "id": "global_kb",
        "label": "知识库",
        "available": True,
        "default": True,
    }
    state: Dict[str, Any] = {
        "needs_interaction": True,
        "current_chapter": chapter,
        "chapter_configs": {},
        "interaction_handler": handler,
        "available_rag_sources": {"global_kb": global_kb_source},
    }

    InteractWithUserNode().execute(state, NodeContext())

    chapter_config = state["chapter_configs"][chapter["id"]]
    assert chapter_config["rag_enabled"] is True
    assert chapter_config["rag_sources"] == ["global_kb"]
class DummyTenderSearcher:
    """Stub tender-document searcher that always yields one fixed match."""

    def search(self, query: str, top_k: int):  # pragma: no cover - simple stub
        """Return a one-element list holding a constant match.

        ``query`` and ``top_k`` are accepted but not used.
        """
        fixed_match = DocumentContextMatch(
            text="Tender snippet",
            section="章节一",
            score=0.9,
            source_type="paragraph",
            metadata={},
        )
        return [fixed_match]
class DummyRagTool:
    """Stub RAG tool with a canned search hit that records generation input."""

    def __init__(self):
        # Context dict from the most recent generate_content call, if any.
        self.last_context: Dict[str, Any] | None = None

    def search(self, query: str, k: int):  # pragma: no cover - simple stub
        """Return a single fixed knowledge-base hit; arguments are ignored."""
        kb_hit = {"content": "KB snippet", "metadata": {"source": "kb.docx"}}
        return [kb_hit]

    def generate_content(self, task_id: str, context: Dict[str, Any]):
        """Remember ``context`` for later inspection and return "generated"."""
        self.last_context = context
        return "generated"
def test_generate_content_merges_multiple_sources():
    """Expect the generation context to carry both source markers.

    The chapter is configured with the ``tender_doc`` and ``global_kb``
    sources. After the node runs, the ``rag_context`` value handed to the
    stub RAG tool must contain both the ``【招标文件】`` and ``【知识库】`` markers.
    """
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}
    rag_tool = DummyRagTool()

    chapter_config = {
        "rag_enabled": True,
        "rag_sources": ["tender_doc", "global_kb"],
    }
    state: Dict[str, Any] = {
        "current_chapter": chapter,
        "chapter_queue": [chapter],
        "chapter_configs": {"chapter_1": chapter_config},
        "chapter_children_map": {},
        "generated_contents": {},
        "rag_tool": rag_tool,
        "tender_doc_searcher": DummyTenderSearcher(),
        "available_rag_sources": {
            "tender_doc": {"available": True},
            "global_kb": {"available": True},
        },
    }

    GenerateContentNode().execute(state, NodeContext())

    # The stub's return value should be stored as the chapter's content.
    assert state["generated_contents"]["chapter_1"] == "generated"

    # The stub records the context it was called with; inspect it.
    captured_context = rag_tool.last_context
    assert captured_context is not None
    rag_context = captured_context.get("rag_context")
    assert "【招标文件】" in rag_context
    assert "【知识库】" in rag_context