feat: default tender doc rag selection
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
This commit is contained in:
parent
c1292fcacc
commit
08db5a66bb
@ -3,11 +3,13 @@
|
||||
负责标书章节内容的自动填写(协调器模式)。
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .single_chapter_agent import SingleChapterAgent
|
||||
from ..utils.document_context import DocumentContextBuilder, DocumentContextSearcher
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -29,7 +31,12 @@ class ContentWriterAgent:
|
||||
"""
|
||||
self.interaction_handler = interaction_handler
|
||||
|
||||
def write_content_sync(self, word_file: str) -> Dict[str, Any]:
|
||||
def write_content_sync(
|
||||
self,
|
||||
word_file: str,
|
||||
project_dir: Optional[str] = None,
|
||||
tender_file: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""同步执行内容填写
|
||||
|
||||
Args:
|
||||
@ -54,11 +61,14 @@ class ContentWriterAgent:
|
||||
|
||||
init_state = {
|
||||
"word_file": word_file,
|
||||
"project_dir": str(project_dir) if project_dir else None,
|
||||
}
|
||||
|
||||
state = init_node.execute(init_state, context)
|
||||
chapter_queue = state["chapter_queue"]
|
||||
|
||||
state = self._prepare_rag_sources(state, project_dir, tender_file)
|
||||
|
||||
logger.info(f"共 {len(chapter_queue)} 个章节待处理")
|
||||
for i, ch in enumerate(chapter_queue):
|
||||
logger.info(f" [{i+1}] {ch['id']} - {ch['title']} (level={ch['level']})")
|
||||
@ -127,4 +137,170 @@ class ContentWriterAgent:
|
||||
"""创建交互模式Agent"""
|
||||
from .interaction import InteractionHandler, InteractionMode
|
||||
handler = InteractionHandler(mode=InteractionMode.INTERACTIVE)
|
||||
return cls(handler)
|
||||
return cls(handler)
|
||||
|
||||
def _prepare_rag_sources(
|
||||
self,
|
||||
state: Dict[str, Any],
|
||||
project_dir: Optional[str],
|
||||
tender_file: Optional[str],
|
||||
) -> Dict[str, Any]:
|
||||
"""构建可用的RAG来源列表,并按需加载招标文件上下文。"""
|
||||
|
||||
tender_path = self._resolve_tender_source(project_dir, tender_file)
|
||||
tender_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
if tender_path:
|
||||
tender_info = self._build_tender_searcher(state, tender_path)
|
||||
|
||||
if not tender_info:
|
||||
prompted_path = self._prompt_for_tender_source()
|
||||
if prompted_path:
|
||||
tender_info = self._build_tender_searcher(state, prompted_path)
|
||||
|
||||
if tender_info:
|
||||
state["tender_doc_source_file"] = tender_info["file_path"]
|
||||
state["tender_doc_searcher"] = tender_info["searcher"]
|
||||
else:
|
||||
state.pop("tender_doc_source_file", None)
|
||||
state.pop("tender_doc_searcher", None)
|
||||
|
||||
state["available_rag_sources"] = self._compose_available_sources(state, tender_info)
|
||||
return state
|
||||
|
||||
def _resolve_tender_source(
|
||||
self,
|
||||
project_dir: Optional[str],
|
||||
tender_file: Optional[str],
|
||||
) -> Optional[str]:
|
||||
if tender_file:
|
||||
path = Path(tender_file)
|
||||
return str(path) if path.exists() else None
|
||||
|
||||
if not project_dir:
|
||||
return None
|
||||
|
||||
analysis_file = Path(project_dir) / "analysis_result.json"
|
||||
if not analysis_file.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(analysis_file, "r", encoding="utf-8") as fp:
|
||||
analysis_data = json.load(fp)
|
||||
except Exception:
|
||||
logger.warning("读取 analysis_result.json 失败,无法自动注入招标上下文")
|
||||
return None
|
||||
|
||||
source_file = analysis_data.get("source_file")
|
||||
if not source_file:
|
||||
return None
|
||||
|
||||
source_path = Path(source_file)
|
||||
if not source_path.exists():
|
||||
logger.warning("analysis_result.json 中的招标文件不存在: %s", source_file)
|
||||
return None
|
||||
|
||||
return str(source_path)
|
||||
|
||||
def _prompt_for_tender_source(self) -> Optional[str]:
|
||||
if not self.interaction_handler:
|
||||
return None
|
||||
|
||||
try:
|
||||
should_provide = self.interaction_handler(
|
||||
interaction_type="confirm",
|
||||
prompt="未找到招标文件上下文,是否手动选择一个招标文件供RAG使用?",
|
||||
default=False,
|
||||
key="provide_tender_doc",
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not should_provide:
|
||||
return None
|
||||
|
||||
path = self.interaction_handler(
|
||||
interaction_type="file_path",
|
||||
prompt="请输入招标文件(.docx)路径",
|
||||
default="",
|
||||
key="tender_doc_path",
|
||||
validation={"exists": True, "extensions": [".docx"]},
|
||||
)
|
||||
if not path:
|
||||
return None
|
||||
file_path = Path(path)
|
||||
return str(file_path) if file_path.exists() else None
|
||||
|
||||
def _build_tender_searcher(
|
||||
self,
|
||||
state: Dict[str, Any],
|
||||
tender_path: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
builder = DocumentContextBuilder()
|
||||
document_context = builder.build(tender_path)
|
||||
if document_context.is_empty():
|
||||
logger.warning("招标文件未生成有效上下文: %s", tender_path)
|
||||
return None
|
||||
|
||||
searcher = DocumentContextSearcher(document_context)
|
||||
logger.info(
|
||||
"已载入招标文件上下文: %s (%s段)",
|
||||
Path(tender_path).name,
|
||||
len(document_context.chunks),
|
||||
)
|
||||
return {
|
||||
"file_path": tender_path,
|
||||
"chunk_count": len(document_context.chunks),
|
||||
"searcher": searcher,
|
||||
}
|
||||
except Exception as exc:
|
||||
logger.warning("构建招标文件上下文失败: %s", exc)
|
||||
return None
|
||||
|
||||
def _compose_available_sources(
|
||||
self,
|
||||
state: Dict[str, Any],
|
||||
tender_info: Optional[Dict[str, Any]],
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
sources: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
if tender_info:
|
||||
file_name = Path(tender_info["file_path"]).name
|
||||
chunk_count = tender_info.get("chunk_count", 0)
|
||||
sources["tender_doc"] = {
|
||||
"id": "tender_doc",
|
||||
"label": f"招标文件({file_name},{chunk_count}段)",
|
||||
"available": True,
|
||||
"default": True,
|
||||
}
|
||||
else:
|
||||
sources["tender_doc"] = {
|
||||
"id": "tender_doc",
|
||||
"label": "招标文件(未配置)",
|
||||
"available": False,
|
||||
"default": False,
|
||||
}
|
||||
|
||||
rag_tool = state.get("rag_tool")
|
||||
kb_label = "知识库(未初始化)"
|
||||
kb_available = False
|
||||
kb_files = 0
|
||||
if rag_tool:
|
||||
try:
|
||||
stats = rag_tool.get_stats()
|
||||
kb_files = stats.get("total_files", 0)
|
||||
chunks = stats.get("total_chunks", 0)
|
||||
kb_label = f"知识库({kb_files}个文档,{chunks}段)"
|
||||
kb_available = True
|
||||
except Exception:
|
||||
kb_label = "知识库(统计失败)"
|
||||
|
||||
sources["global_kb"] = {
|
||||
"id": "global_kb",
|
||||
"label": kb_label,
|
||||
"available": kb_available,
|
||||
"default": not tender_info,
|
||||
}
|
||||
|
||||
return sources
|
||||
@ -36,7 +36,13 @@ def write():
|
||||
default="interactive",
|
||||
help="交互模式:interactive(交互)|silent(静默)",
|
||||
)
|
||||
def start(project_dir: str, mode: str):
|
||||
@click.option(
|
||||
"--tender-file",
|
||||
"-t",
|
||||
type=click.Path(exists=True, dir_okay=False),
|
||||
help="用于RAG的招标文件(.docx)",
|
||||
)
|
||||
def start(project_dir: str, mode: str, tender_file: str | None = None):
|
||||
"""开始填写标书内容
|
||||
|
||||
PROJECT_DIR: 项目目录路径(默认当前目录)
|
||||
@ -58,7 +64,8 @@ def start(project_dir: str, mode: str):
|
||||
console.print(f"❌ 项目目录不存在: {project_dir}", style="red")
|
||||
return
|
||||
|
||||
console.print(f"📁 项目目录: {project_path.absolute()}", style="dim")
|
||||
project_abs_path = project_path.resolve()
|
||||
console.print(f"📁 项目目录: {project_abs_path}", style="dim")
|
||||
|
||||
# 自动查找Word文件
|
||||
word_file = _find_word_file(project_path)
|
||||
@ -68,6 +75,9 @@ def start(project_dir: str, mode: str):
|
||||
console.print(f"📄 找到Word文档: {Path(word_file).name}", style="green")
|
||||
console.print("📋 将直接从Word文档提取章节结构", style="dim")
|
||||
|
||||
# 规范化可选招标文件路径
|
||||
tender_file_path = str(Path(tender_file).resolve()) if tender_file else None
|
||||
|
||||
# 创建Agent
|
||||
if mode == "silent":
|
||||
agent = ContentWriterAgent.create_silent()
|
||||
@ -91,11 +101,19 @@ def start(project_dir: str, mode: str):
|
||||
console=console,
|
||||
) as progress:
|
||||
task = progress.add_task("正在处理章节...", total=None)
|
||||
result = agent.write_content_sync(word_file=word_file)
|
||||
result = agent.write_content_sync(
|
||||
word_file=word_file,
|
||||
project_dir=str(project_abs_path),
|
||||
tender_file=tender_file_path,
|
||||
)
|
||||
progress.update(task, completed=True)
|
||||
else:
|
||||
# 交互模式:不使用Progress,避免覆盖用户输入
|
||||
result = agent.write_content_sync(word_file=word_file)
|
||||
result = agent.write_content_sync(
|
||||
word_file=word_file,
|
||||
project_dir=str(project_abs_path),
|
||||
tender_file=tender_file_path,
|
||||
)
|
||||
|
||||
# 显示结果
|
||||
if result.get("success"):
|
||||
|
||||
@ -83,6 +83,7 @@ class Settings(BaseSettings):
|
||||
# Word填充配置
|
||||
max_sub_chapter_level: int = Field(default=3, description="子章节最大层级")
|
||||
rag_search_top_k: int = Field(default=3, description="RAG检索返回结果数量")
|
||||
tender_doc_search_top_k: int = Field(default=3, description="招标文档检索返回结果数量")
|
||||
parent_context_length: int = Field(default=500, description="父章节上下文长度限制")
|
||||
|
||||
# Word格式配置
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..base import BaseNode, NodeContext
|
||||
from ...config.settings import get_settings
|
||||
@ -115,8 +115,9 @@ class GenerateContentNode(BaseNode):
|
||||
"""
|
||||
return {
|
||||
"emphasis": "",
|
||||
"rag_enabled": False,
|
||||
"rag_store": None
|
||||
"rag_enabled": True,
|
||||
"rag_store": "tender_doc",
|
||||
"rag_sources": ["tender_doc"],
|
||||
}
|
||||
|
||||
def _find_chapter(self, state: Dict[str, Any], chapter_id: str) -> Optional[Dict[str, Any]]:
|
||||
@ -193,9 +194,8 @@ class GenerateContentNode(BaseNode):
|
||||
prompt_spec = planner.build_prompt_spec(chapter)
|
||||
generation_context = dict(prompt_spec)
|
||||
|
||||
# 如果启用RAG,添加上下文信息
|
||||
# 如果启用RAG,按来源聚合上下文
|
||||
if config.get("rag_enabled"):
|
||||
# 检索相关内容
|
||||
query_fragments = [chapter["title"]]
|
||||
if prompt_spec.get("emphasis"):
|
||||
query_fragments.append(prompt_spec["emphasis"])
|
||||
@ -203,14 +203,22 @@ class GenerateContentNode(BaseNode):
|
||||
query_fragments.append(prompt_spec["requirements_summary"][:120])
|
||||
|
||||
query = " ".join(fragment for fragment in query_fragments if fragment)
|
||||
search_results = rag_tool.search(query, k=settings.rag_search_top_k)
|
||||
|
||||
if search_results:
|
||||
relevant_context = "\n\n".join([r["content"] for r in search_results])
|
||||
generation_context["rag_context"] = relevant_context
|
||||
logger.info(f"RAG检索到 {len(search_results)} 条相关内容")
|
||||
rag_sources = config.get("rag_sources") or self._default_rag_sources(state)
|
||||
config["rag_sources"] = rag_sources
|
||||
|
||||
aggregated_context: List[str] = []
|
||||
for source_id in rag_sources:
|
||||
if source_id == "tender_doc":
|
||||
aggregated_context.extend(self._search_tender_doc(state, query))
|
||||
elif source_id == "global_kb":
|
||||
aggregated_context.extend(self._search_global_kb(state, rag_tool, query))
|
||||
else:
|
||||
logger.warning("未知的RAG来源: %s", source_id)
|
||||
|
||||
if aggregated_context:
|
||||
generation_context["rag_context"] = "\n\n".join(aggregated_context)
|
||||
else:
|
||||
logger.warning("RAG未检索到相关内容")
|
||||
generation_context["rag_context"] = ""
|
||||
else:
|
||||
generation_context["rag_context"] = ""
|
||||
@ -221,4 +229,69 @@ class GenerateContentNode(BaseNode):
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.error(f"内容生成失败: {e}", exc_info=True)
|
||||
raise ValueError(f"章节 {chapter['id']} - {chapter['title']} 内容生成失败") from e
|
||||
raise ValueError(f"章节 {chapter['id']} - {chapter['title']} 内容生成失败") from e
|
||||
|
||||
def _default_rag_sources(self, state: Dict[str, Any]) -> List[str]:
|
||||
sources = state.get("available_rag_sources") or {}
|
||||
preferred: List[str] = []
|
||||
if sources.get("tender_doc", {}).get("available"):
|
||||
preferred.append("tender_doc")
|
||||
if sources.get("global_kb", {}).get("available"):
|
||||
preferred.append("global_kb")
|
||||
if preferred:
|
||||
return preferred
|
||||
if state.get("rag_tool"):
|
||||
return ["global_kb"]
|
||||
return []
|
||||
|
||||
def _search_tender_doc(self, state: Dict[str, Any], query: str) -> List[str]:
|
||||
searcher = state.get("tender_doc_searcher")
|
||||
if not searcher:
|
||||
return []
|
||||
|
||||
top_k = getattr(settings, "tender_doc_search_top_k", settings.rag_search_top_k)
|
||||
try:
|
||||
matches = searcher.search(query, top_k)
|
||||
except Exception as exc:
|
||||
logger.warning("招标文件RAG检索失败: %s", exc)
|
||||
return []
|
||||
|
||||
contexts: List[str] = []
|
||||
for match in matches:
|
||||
section = getattr(match, "section", "招标文件") or "招标文件"
|
||||
text = getattr(match, "text", "")
|
||||
if not text:
|
||||
continue
|
||||
contexts.append(f"【招标文件】{section}\n{text}")
|
||||
|
||||
if contexts:
|
||||
logger.info("招标文件RAG检索到 %s 条内容", len(contexts))
|
||||
else:
|
||||
logger.warning("招标文件RAG未检索到相关内容")
|
||||
|
||||
return contexts
|
||||
|
||||
def _search_global_kb(self, state: Dict[str, Any], rag_tool, query: str) -> List[str]:
|
||||
if not rag_tool:
|
||||
return []
|
||||
|
||||
try:
|
||||
results = rag_tool.search(query, k=settings.rag_search_top_k)
|
||||
except Exception as exc:
|
||||
logger.warning("知识库RAG检索失败: %s", exc)
|
||||
return []
|
||||
|
||||
contexts: List[str] = []
|
||||
for result in results or []:
|
||||
content = result.get("content")
|
||||
if not content:
|
||||
continue
|
||||
source = (result.get("metadata") or {}).get("source") or "知识库片段"
|
||||
contexts.append(f"【知识库】{source}\n{content}")
|
||||
|
||||
if contexts:
|
||||
logger.info("知识库RAG检索到 %s 条内容", len(contexts))
|
||||
else:
|
||||
logger.warning("知识库RAG未检索到相关内容")
|
||||
|
||||
return contexts
|
||||
@ -4,7 +4,7 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..base import BaseNode, NodeContext
|
||||
|
||||
@ -80,17 +80,24 @@ class InteractWithUserNode(BaseNode):
|
||||
)
|
||||
|
||||
use_rag = use_rag_response == "是"
|
||||
rag_store = None
|
||||
rag_sources: list[str] = []
|
||||
rag_store: Optional[str] = None
|
||||
|
||||
# 3. 如果使用RAG,选择知识库
|
||||
if use_rag:
|
||||
rag_store = self._select_rag_store(interaction_handler, chapter_id, state)
|
||||
rag_sources = self._select_rag_sources(interaction_handler, chapter_id, state)
|
||||
if rag_sources:
|
||||
rag_store = rag_sources[0]
|
||||
else:
|
||||
logger.warning("RAG已启用但没有可用的知识库来源,将自动禁用")
|
||||
use_rag = False
|
||||
|
||||
# 保存章节配置
|
||||
config = {
|
||||
"emphasis": emphasis.strip() if emphasis else "",
|
||||
"rag_enabled": use_rag,
|
||||
"rag_store": rag_store,
|
||||
"rag_sources": rag_sources if use_rag else [],
|
||||
}
|
||||
|
||||
state.setdefault("chapter_configs", {})[chapter_id] = config
|
||||
@ -102,45 +109,96 @@ class InteractWithUserNode(BaseNode):
|
||||
|
||||
return state
|
||||
|
||||
def _select_rag_store(self, interaction_handler, chapter_id: str, state: Dict[str, Any]) -> str:
|
||||
"""选择RAG知识库
|
||||
def _select_rag_sources(
|
||||
self,
|
||||
interaction_handler,
|
||||
chapter_id: str,
|
||||
state: Dict[str, Any],
|
||||
) -> list[str]:
|
||||
available_sources = self._get_available_sources(state)
|
||||
if not available_sources:
|
||||
return []
|
||||
|
||||
Args:
|
||||
interaction_handler: 交互处理器
|
||||
chapter_id: 章节ID
|
||||
state: 当前状态
|
||||
ordered_ids = self._order_sources(available_sources)
|
||||
selected: List[str] = []
|
||||
|
||||
Returns:
|
||||
知识库标识
|
||||
"""
|
||||
# 从state获取RAGTool实例(由InitConfigNode统一初始化)
|
||||
rag_tool = state.get("rag_tool")
|
||||
if not rag_tool:
|
||||
raise ValueError("RAGTool未初始化,请检查InitConfigNode配置")
|
||||
|
||||
try:
|
||||
stats = rag_tool.get_stats()
|
||||
total_chunks = stats.get("total_chunks", 0)
|
||||
total_files = stats.get("total_files", 0)
|
||||
|
||||
# 显示知识库统计信息
|
||||
store_info = f"默认知识库 ({total_files}个文档, {total_chunks}个文档块)"
|
||||
|
||||
choice = interaction_handler(
|
||||
if "tender_doc" in available_sources and available_sources["tender_doc"].get("available"):
|
||||
# 默认启用招标文件,不再询问主来源
|
||||
selected.append("tender_doc")
|
||||
remaining = [sid for sid in ordered_ids if sid != "tender_doc"]
|
||||
else:
|
||||
# 无招标文件时,仍需选择主来源
|
||||
default_source = ordered_ids[0]
|
||||
options = [
|
||||
(source_id, available_sources[source_id]["label"])
|
||||
for source_id in ordered_ids
|
||||
]
|
||||
primary_choice = interaction_handler(
|
||||
interaction_type="choice",
|
||||
prompt="选择知识库",
|
||||
options=[
|
||||
("default", store_info),
|
||||
("none", "不使用RAG")
|
||||
],
|
||||
default="default",
|
||||
key=f"rag_store_{chapter_id}",
|
||||
prompt="选择主要RAG来源",
|
||||
options=options,
|
||||
default=default_source,
|
||||
key=f"rag_primary_{chapter_id}",
|
||||
)
|
||||
if primary_choice not in available_sources:
|
||||
primary_choice = default_source
|
||||
selected.append(primary_choice)
|
||||
remaining = [sid for sid in ordered_ids if sid not in selected]
|
||||
|
||||
loop_index = 0
|
||||
while remaining:
|
||||
add_more = interaction_handler(
|
||||
interaction_type="confirm",
|
||||
prompt="是否追加其它RAG来源?",
|
||||
default=False,
|
||||
key=f"rag_additional_confirm_{chapter_id}_{loop_index}",
|
||||
)
|
||||
if not add_more:
|
||||
break
|
||||
|
||||
add_options = [
|
||||
(source_id, available_sources[source_id]["label"])
|
||||
for source_id in remaining
|
||||
]
|
||||
next_choice = interaction_handler(
|
||||
interaction_type="choice",
|
||||
prompt="选择需要追加的RAG来源",
|
||||
options=add_options,
|
||||
default=remaining[0],
|
||||
key=f"rag_additional_choice_{chapter_id}_{loop_index}",
|
||||
)
|
||||
|
||||
if choice == "default":
|
||||
return "default"
|
||||
return None
|
||||
if next_choice in remaining:
|
||||
selected.append(next_choice)
|
||||
remaining.remove(next_choice)
|
||||
loop_index += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取知识库信息失败: {e}", exc_info=True)
|
||||
raise ValueError("获取知识库信息失败,无法选择RAG存储") from e
|
||||
return selected
|
||||
|
||||
def _get_available_sources(self, state: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
|
||||
sources = state.get("available_rag_sources") or {}
|
||||
usable = {sid: meta for sid, meta in sources.items() if meta.get("available")}
|
||||
|
||||
if usable:
|
||||
return usable
|
||||
|
||||
if state.get("rag_tool"):
|
||||
return {
|
||||
"global_kb": {
|
||||
"id": "global_kb",
|
||||
"label": "知识库",
|
||||
"available": True,
|
||||
"default": True,
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
def _order_sources(self, sources: Dict[str, Dict[str, Any]]) -> list[str]:
|
||||
return sorted(
|
||||
sources.keys(),
|
||||
key=lambda sid: (
|
||||
not sources[sid].get("default", False),
|
||||
sources[sid].get("label", sid),
|
||||
),
|
||||
)
|
||||
173
tests/unit/test_rag_selection.py
Normal file
173
tests/unit/test_rag_selection.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""Tests for multi-source RAG selection and aggregation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from bidmaster.nodes.base import NodeContext
|
||||
from bidmaster.nodes.content.generate_content import GenerateContentNode
|
||||
from bidmaster.nodes.content.interact_user import InteractWithUserNode
|
||||
from bidmaster.utils.document_context import DocumentContextMatch
|
||||
|
||||
|
||||
class StubInteractionHandler:
|
||||
def __init__(self, responses: Dict[str, Any]):
|
||||
self.responses = responses
|
||||
|
||||
def __call__(self, interaction_type: str, **kwargs):
|
||||
key = kwargs.get("key")
|
||||
if key and key in self.responses:
|
||||
return self.responses[key]
|
||||
return kwargs.get("default")
|
||||
|
||||
|
||||
def test_interact_user_allows_multiple_rag_sources():
|
||||
chapter = {"id": "chapter_1", "title": "概述", "level": 1}
|
||||
responses = {
|
||||
"emphasis_chapter_1": "",
|
||||
"use_rag_chapter_1": "是",
|
||||
"rag_additional_confirm_chapter_1_0": True,
|
||||
"rag_additional_choice_chapter_1_0": "global_kb",
|
||||
}
|
||||
handler = StubInteractionHandler(responses)
|
||||
|
||||
state: Dict[str, Any] = {
|
||||
"needs_interaction": True,
|
||||
"current_chapter": chapter,
|
||||
"chapter_configs": {},
|
||||
"interaction_handler": handler,
|
||||
"available_rag_sources": {
|
||||
"tender_doc": {
|
||||
"id": "tender_doc",
|
||||
"label": "招标文件",
|
||||
"available": True,
|
||||
"default": True,
|
||||
},
|
||||
"global_kb": {
|
||||
"id": "global_kb",
|
||||
"label": "知识库",
|
||||
"available": True,
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
node = InteractWithUserNode()
|
||||
node.execute(state, NodeContext())
|
||||
|
||||
config = state["chapter_configs"][chapter["id"]]
|
||||
assert config["rag_enabled"] is True
|
||||
assert config["rag_sources"] == ["tender_doc", "global_kb"]
|
||||
|
||||
|
||||
def test_interact_user_disables_rag_without_sources():
|
||||
chapter = {"id": "chapter_1", "title": "概述", "level": 1}
|
||||
responses = {
|
||||
"emphasis_chapter_1": "",
|
||||
"use_rag_chapter_1": "是",
|
||||
}
|
||||
handler = StubInteractionHandler(responses)
|
||||
|
||||
state: Dict[str, Any] = {
|
||||
"needs_interaction": True,
|
||||
"current_chapter": chapter,
|
||||
"chapter_configs": {},
|
||||
"interaction_handler": handler,
|
||||
"available_rag_sources": {},
|
||||
}
|
||||
|
||||
node = InteractWithUserNode()
|
||||
node.execute(state, NodeContext())
|
||||
|
||||
config = state["chapter_configs"][chapter["id"]]
|
||||
assert config["rag_enabled"] is False
|
||||
assert config["rag_sources"] == []
|
||||
|
||||
|
||||
def test_interact_user_prompts_for_primary_when_no_tender_doc():
|
||||
chapter = {"id": "chapter_1", "title": "概述", "level": 1}
|
||||
responses = {
|
||||
"emphasis_chapter_1": "",
|
||||
"use_rag_chapter_1": "是",
|
||||
"rag_primary_chapter_1": "global_kb",
|
||||
}
|
||||
handler = StubInteractionHandler(responses)
|
||||
|
||||
state: Dict[str, Any] = {
|
||||
"needs_interaction": True,
|
||||
"current_chapter": chapter,
|
||||
"chapter_configs": {},
|
||||
"interaction_handler": handler,
|
||||
"available_rag_sources": {
|
||||
"global_kb": {
|
||||
"id": "global_kb",
|
||||
"label": "知识库",
|
||||
"available": True,
|
||||
"default": True,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
node = InteractWithUserNode()
|
||||
node.execute(state, NodeContext())
|
||||
|
||||
config = state["chapter_configs"][chapter["id"]]
|
||||
assert config["rag_enabled"] is True
|
||||
assert config["rag_sources"] == ["global_kb"]
|
||||
|
||||
|
||||
class DummyTenderSearcher:
|
||||
def search(self, query: str, top_k: int): # pragma: no cover - simple stub
|
||||
return [
|
||||
DocumentContextMatch(
|
||||
text="Tender snippet",
|
||||
section="章节一",
|
||||
score=0.9,
|
||||
source_type="paragraph",
|
||||
metadata={},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class DummyRagTool:
|
||||
def __init__(self):
|
||||
self.last_context: Dict[str, Any] | None = None
|
||||
|
||||
def search(self, query: str, k: int): # pragma: no cover - simple stub
|
||||
return [{"content": "KB snippet", "metadata": {"source": "kb.docx"}}]
|
||||
|
||||
def generate_content(self, task_id: str, context: Dict[str, Any]):
|
||||
self.last_context = context
|
||||
return "generated"
|
||||
|
||||
|
||||
def test_generate_content_merges_multiple_sources():
|
||||
chapter = {"id": "chapter_1", "title": "概述", "level": 1}
|
||||
rag_tool = DummyRagTool()
|
||||
state: Dict[str, Any] = {
|
||||
"current_chapter": chapter,
|
||||
"chapter_queue": [chapter],
|
||||
"chapter_configs": {
|
||||
"chapter_1": {
|
||||
"rag_enabled": True,
|
||||
"rag_sources": ["tender_doc", "global_kb"],
|
||||
}
|
||||
},
|
||||
"chapter_children_map": {},
|
||||
"generated_contents": {},
|
||||
"rag_tool": rag_tool,
|
||||
"tender_doc_searcher": DummyTenderSearcher(),
|
||||
"available_rag_sources": {
|
||||
"tender_doc": {"available": True},
|
||||
"global_kb": {"available": True},
|
||||
},
|
||||
}
|
||||
|
||||
node = GenerateContentNode()
|
||||
node.execute(state, NodeContext())
|
||||
|
||||
assert state["generated_contents"]["chapter_1"] == "generated"
|
||||
assert rag_tool.last_context is not None
|
||||
rag_context = rag_tool.last_context.get("rag_context")
|
||||
assert "【招标文件】" in rag_context
|
||||
assert "【知识库】" in rag_context
|
||||
Loading…
Reference in New Issue
Block a user