Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
174 lines
5.2 KiB
Python
174 lines
5.2 KiB
Python
"""Tests for multi-source RAG selection and aggregation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Dict
|
|
|
|
from bidmaster.nodes.base import NodeContext
|
|
from bidmaster.nodes.content.generate_content import GenerateContentNode
|
|
from bidmaster.nodes.content.interact_user import InteractWithUserNode
|
|
from bidmaster.utils.document_context import DocumentContextMatch
|
|
|
|
|
|
class StubInteractionHandler:
    """Canned-answer replacement for the interactive user prompt handler.

    Each call is resolved by its ``key`` keyword argument. When the key is
    missing, empty, or has no scripted answer, the call's own ``default``
    keyword argument is returned instead (``None`` if it was not given).
    """

    def __init__(self, responses: Dict[str, Any]):
        # Scripted answers, indexed by interaction key.
        self.responses = responses

    def __call__(self, interaction_type: str, **kwargs):
        # interaction_type is accepted for signature compatibility but unused.
        lookup_key = kwargs.get("key")
        if not lookup_key or lookup_key not in self.responses:
            # Behave like a user who simply accepts the offered default.
            return kwargs.get("default")
        return self.responses[lookup_key]
|
|
|
|
|
|
def test_interact_user_allows_multiple_rag_sources():
    """Picking an extra RAG source keeps the default one and appends the new one.

    Both the tender document (the default) and the knowledge base are
    available. The scripted answers enable RAG, then confirm and choose
    ``global_kb`` as an additional source. The chapter config must therefore
    list both sources, tender document first.
    """
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}  # "概述" = "Overview"

    # Registry of sources: tender document ("招标文件", default) and
    # knowledge base ("知识库", opt-in).
    rag_sources = {
        "tender_doc": {
            "id": "tender_doc",
            "label": "招标文件",
            "available": True,
            "default": True,
        },
        "global_kb": {
            "id": "global_kb",
            "label": "知识库",
            "available": True,
            "default": False,
        },
    }

    # Scripted answers. "是" means "yes". The "_0" suffix presumably addresses
    # the first "add another source?" round -- confirm against the node.
    answers = {
        "emphasis_chapter_1": "",
        "use_rag_chapter_1": "是",
        "rag_additional_confirm_chapter_1_0": True,
        "rag_additional_choice_chapter_1_0": "global_kb",
    }

    state: Dict[str, Any] = {
        "needs_interaction": True,
        "current_chapter": chapter,
        "chapter_configs": {},
        "interaction_handler": StubInteractionHandler(answers),
        "available_rag_sources": rag_sources,
    }

    InteractWithUserNode().execute(state, NodeContext())

    chapter_config = state["chapter_configs"][chapter["id"]]
    assert chapter_config["rag_enabled"] is True
    assert chapter_config["rag_sources"] == ["tender_doc", "global_kb"]
|
|
|
|
|
|
def test_interact_user_disables_rag_without_sources():
    """Opting into RAG with no registered sources must leave RAG disabled."""
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}  # "概述" = "Overview"

    # The user answers "是" ("yes") to RAG, but there is nothing to search.
    answers = {
        "emphasis_chapter_1": "",
        "use_rag_chapter_1": "是",
    }

    state: Dict[str, Any] = {
        "needs_interaction": True,
        "current_chapter": chapter,
        "chapter_configs": {},
        "interaction_handler": StubInteractionHandler(answers),
        "available_rag_sources": {},
    }

    InteractWithUserNode().execute(state, NodeContext())

    chapter_config = state["chapter_configs"][chapter["id"]]
    assert chapter_config["rag_enabled"] is False
    assert chapter_config["rag_sources"] == []
|
|
|
|
|
|
def test_interact_user_prompts_for_primary_when_no_tender_doc():
    """Without a tender document, the chosen primary source is used alone.

    Only the knowledge base is registered. The scripted answer for the
    ``rag_primary_*`` key selects it, so it should be the chapter's only
    RAG source.
    """
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}  # "概述" = "Overview"

    # The only available source is the knowledge base ("知识库").
    rag_sources = {
        "global_kb": {
            "id": "global_kb",
            "label": "知识库",
            "available": True,
            "default": True,
        }
    }

    # "是" means "yes". The primary-source answer is presumably requested
    # because no tender document exists -- confirm against the node.
    answers = {
        "emphasis_chapter_1": "",
        "use_rag_chapter_1": "是",
        "rag_primary_chapter_1": "global_kb",
    }

    state: Dict[str, Any] = {
        "needs_interaction": True,
        "current_chapter": chapter,
        "chapter_configs": {},
        "interaction_handler": StubInteractionHandler(answers),
        "available_rag_sources": rag_sources,
    }

    InteractWithUserNode().execute(state, NodeContext())

    chapter_config = state["chapter_configs"][chapter["id"]]
    assert chapter_config["rag_enabled"] is True
    assert chapter_config["rag_sources"] == ["global_kb"]
|
|
|
|
|
|
class DummyTenderSearcher:
    """Tender-document searcher stub that always yields one fixed match."""

    def search(self, query: str, top_k: int):  # pragma: no cover - simple stub
        # query and top_k are ignored; every search returns the same snippet.
        # "章节一" = "Chapter One".
        only_match = DocumentContextMatch(
            text="Tender snippet",
            section="章节一",
            score=0.9,
            source_type="paragraph",
            metadata={},
        )
        return [only_match]
|
|
|
|
|
|
class DummyRagTool:
    """RAG tool stub returning a fixed knowledge-base hit.

    ``generate_content`` records the context it receives in ``last_context``
    so tests can inspect what the node passed in.
    """

    def __init__(self):
        # None until generate_content has been called at least once.
        self.last_context: Dict[str, Any] | None = None

    def search(self, query: str, k: int):  # pragma: no cover - simple stub
        # query and k are ignored; a single knowledge-base hit is returned.
        kb_hit = {"content": "KB snippet", "metadata": {"source": "kb.docx"}}
        return [kb_hit]

    def generate_content(self, task_id: str, context: Dict[str, Any]):
        # Keep a reference to the caller's dict itself (no copy is made).
        self.last_context = context
        return "generated"
|
|
|
|
|
|
def test_generate_content_merges_multiple_sources():
    """Content generation should merge snippets from every enabled RAG source.

    The chapter enables both the tender document and the global knowledge base.
    The ``rag_context`` handed to the RAG tool must contain a header for each:
    【招标文件】 ("tender document") and 【知识库】 ("knowledge base").
    """
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}  # "概述" = "Overview"
    rag_tool = DummyRagTool()
    state: Dict[str, Any] = {
        "current_chapter": chapter,
        "chapter_queue": [chapter],
        "chapter_configs": {
            "chapter_1": {
                "rag_enabled": True,
                "rag_sources": ["tender_doc", "global_kb"],
            }
        },
        "chapter_children_map": {},
        "generated_contents": {},
        "rag_tool": rag_tool,
        "tender_doc_searcher": DummyTenderSearcher(),
        "available_rag_sources": {
            "tender_doc": {"available": True},
            "global_kb": {"available": True},
        },
    }

    node = GenerateContentNode()
    node.execute(state, NodeContext())

    assert state["generated_contents"]["chapter_1"] == "generated"
    # DummyRagTool only sets last_context inside generate_content.
    assert rag_tool.last_context is not None, "rag_tool.generate_content was never called"
    rag_context = rag_tool.last_context.get("rag_context")
    # Without this check, a missing rag_context makes the `in` tests below raise
    # TypeError instead of failing with a readable assertion message.
    assert rag_context is not None, "generation context has no 'rag_context' entry"
    assert "【招标文件】" in rag_context
    assert "【知识库】" in rag_context
|