refactor: automate content filling and default rag

Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
This commit is contained in:
sladro 2025-12-05 10:19:50 +08:00
parent 08db5a66bb
commit 61026cd1b6
7 changed files with 90 additions and 408 deletions

View File

@ -10,7 +10,6 @@ from ..base import AgentBuilder
from ...nodes.content import (
InitConfigNode,
PrepareChapterNode,
InteractWithUserNode,
GenerateContentNode,
SaveToWordNode,
CollectResultsNode,
@ -19,19 +18,6 @@ from ...nodes.content import (
logger = logging.getLogger(__name__)
def need_user_interaction(state: Dict[str, Any]) -> str:
"""判断是否需要用户交互
Args:
state: 当前状态
Returns:
路由键"interact""generate"
"""
needs_interaction = state.get("needs_interaction", False)
return "interact" if needs_interaction else "generate"
def should_continue_loop(state: Dict[str, Any]) -> str:
"""判断是否继续循环处理章节
@ -59,7 +45,6 @@ class ContentWriterAgentBuilder(AgentBuilder):
1. 初始化配置并生成章节队列
2. 循环处理每个章节
- 准备章节
- 用户交互仅1级标题
- 生成内容
- 保存到Word
3. 收集结果
@ -80,7 +65,6 @@ class ContentWriterAgentBuilder(AgentBuilder):
# 添加所有节点
builder.add_node(InitConfigNode()) \
.add_node(PrepareChapterNode()) \
.add_node(InteractWithUserNode()) \
.add_node(GenerateContentNode()) \
.add_node(SaveToWordNode()) \
.add_node(CollectResultsNode())
@ -99,15 +83,8 @@ class ContentWriterAgentBuilder(AgentBuilder):
# init_config → prepare_chapter
self.add_edge("init_config", "prepare_chapter")
# prepare_chapter → 条件分支(是否需要交互)
self.add_conditional_edge(
"prepare_chapter",
need_user_interaction,
{"interact": "interact_user", "generate": "generate_content"},
)
# interact_user → generate_content
self.add_edge("interact_user", "generate_content")
# prepare_chapter → generate_content
self.add_edge("prepare_chapter", "generate_content")
# generate_content → save_to_word
self.add_edge("generate_content", "save_to_word")

View File

@ -8,7 +8,6 @@ from typing import Any, Dict
from .base import BaseAgent, AgentBuilder
from ..nodes.content import (
InteractWithUserNode,
GenerateContentNode,
SaveToWordNode,
)
@ -16,18 +15,6 @@ from ..nodes.content import (
logger = logging.getLogger(__name__)
def need_user_interaction(state: Dict[str, Any]) -> str:
"""判断是否需要用户交互"""
needs_interaction = state.get("needs_interaction", False)
return "interact" if needs_interaction else "generate"
def route_start(state: Dict[str, Any]) -> str:
"""入口路由:判断是否需要交互"""
needs_interaction = state.get("needs_interaction", False)
return "interact" if needs_interaction else "generate"
class SingleChapterAgentBuilder(AgentBuilder):
"""单章节Agent构建器"""
@ -37,15 +24,13 @@ class SingleChapterAgentBuilder(AgentBuilder):
builder = cls(interaction_handler)
# 添加节点去掉PrepareChapterNode因为外层已准备好
builder.add_node(InteractWithUserNode()) \
.add_node(GenerateContentNode()) \
builder.add_node(GenerateContentNode()) \
.add_node(SaveToWordNode())
# 设置入口interact_user内部会判断是否真的需要交互
builder.set_entry("interact_user")
# 设置入口
builder.set_entry("generate_content")
# 配置流程:线性流程
builder.add_edge("interact_user", "generate_content")
builder.add_edge("generate_content", "save_to_word")
builder.add_edge("save_to_word", "END")
@ -73,7 +58,6 @@ class SingleChapterAgent(BaseAgent):
chapter_state = {
**state, # 继承所有字段包括rag_tool、expanded_configs等
"current_chapter": chapter, # 覆盖当前章节
"needs_interaction": chapter["level"] == 1, # 覆盖交互标志
}
# 执行

View File

@ -5,7 +5,6 @@
from .init_config import InitConfigNode
from .prepare_chapter import PrepareChapterNode
from .interact_user import InteractWithUserNode
from .generate_content import GenerateContentNode
from .save_to_word import SaveToWordNode
from .collect_results import CollectResultsNode
@ -13,7 +12,6 @@ from .collect_results import CollectResultsNode
__all__ = [
"InitConfigNode",
"PrepareChapterNode",
"InteractWithUserNode",
"GenerateContentNode",
"SaveToWordNode",
"CollectResultsNode",

View File

@ -115,9 +115,7 @@ class GenerateContentNode(BaseNode):
"""
return {
"emphasis": "",
"rag_enabled": True,
"rag_store": "tender_doc",
"rag_sources": ["tender_doc"],
"rag_sources": [],
}
def _find_chapter(self, state: Dict[str, Any], chapter_id: str) -> Optional[Dict[str, Any]]:
@ -193,35 +191,36 @@ class GenerateContentNode(BaseNode):
prompt_spec = planner.build_prompt_spec(chapter)
generation_context = dict(prompt_spec)
outline_overview = prompt_spec.get("outline_overview") or state.get("outline_summary") or ""
relation_hints = prompt_spec.get("relation_hints") or ""
generation_context["relation_hints"] = relation_hints
generation_context["outline_overview"] = outline_overview
# 如果启用RAG按来源聚合上下文
if config.get("rag_enabled"):
query_fragments = [chapter["title"]]
if prompt_spec.get("emphasis"):
query_fragments.append(prompt_spec["emphasis"])
if prompt_spec.get("requirements_summary"):
query_fragments.append(prompt_spec["requirements_summary"][:120])
query_fragments = [chapter["title"]]
if prompt_spec.get("emphasis"):
query_fragments.append(prompt_spec["emphasis"])
if prompt_spec.get("requirements_summary"):
query_fragments.append(prompt_spec["requirements_summary"][:120])
if relation_hints:
query_fragments.append(relation_hints)
if outline_overview:
query_fragments.append(outline_overview[:200])
query = " ".join(fragment for fragment in query_fragments if fragment)
query = " ".join(fragment for fragment in query_fragments if fragment)
rag_sources = config.get("rag_sources") or self._default_rag_sources(state)
config["rag_sources"] = rag_sources
rag_sources = config.get("rag_sources") or self._default_rag_sources(state)
config["rag_sources"] = rag_sources
aggregated_context: List[str] = []
for source_id in rag_sources:
if source_id == "tender_doc":
aggregated_context.extend(self._search_tender_doc(state, query))
elif source_id == "global_kb":
aggregated_context.extend(self._search_global_kb(state, rag_tool, query))
else:
logger.warning("未知的RAG来源: %s", source_id)
if aggregated_context:
generation_context["rag_context"] = "\n\n".join(aggregated_context)
aggregated_context: List[str] = []
for source_id in rag_sources:
if source_id == "tender_doc":
aggregated_context.extend(self._search_tender_doc(state, query))
elif source_id == "global_kb":
aggregated_context.extend(self._search_global_kb(state, rag_tool, query))
else:
generation_context["rag_context"] = ""
else:
generation_context["rag_context"] = ""
logger.warning("未知的RAG来源: %s", source_id)
generation_context["rag_context"] = "\n\n".join(aggregated_context) if aggregated_context else ""
# 调用生成方法
try:

View File

@ -1,204 +0,0 @@
"""用户交互节点
与用户交互获取章节填写要求和RAG配置
"""
import logging
from typing import Any, Dict, List, Optional
from ..base import BaseNode, NodeContext
logger = logging.getLogger(__name__)
class InteractWithUserNode(BaseNode):
"""用户交互节点
职责
1. 询问用户该章节需要强调的内容
2. 询问是否使用RAG知识库
3. 如果使用RAG选择具体的知识库
4. 保存配置供后续子章节继承
"""
@property
def name(self) -> str:
return "interact_user"
@property
def description(self) -> str:
return "与用户交互获取章节配置"
def execute(self, state: Dict[str, Any], context: NodeContext) -> Dict[str, Any]:
"""执行用户交互
Args:
state: 当前状态
context: 执行上下文
Returns:
更新后的状态
"""
chapter = state.get("current_chapter", {})
chapter_id = chapter.get("id")
chapter_title = chapter.get("title")
if not chapter_id:
raise ValueError("当前章节信息缺失")
# 检查是否需要交互
needs_interaction = state.get("needs_interaction", False)
if not needs_interaction:
# 不需要交互,直接返回
return state
# 获取交互处理器
interaction_handler = state.get("interaction_handler")
if not interaction_handler:
# 无交互处理器时立即失败,暴露配置问题
raise ValueError("交互处理器未配置,无法获取章节配置")
logger.info(f"开始与用户交互,获取章节配置: {chapter_id} - {chapter_title}")
# 1. 询问需要强调的内容
emphasis = interaction_handler(
interaction_type="text",
prompt=f"📝 正在处理章节:{chapter_id}. {chapter_title}\n"
f" 是否有需要特别强调的内容?(直接回车跳过)",
default="",
key=f"emphasis_{chapter_id}",
)
# 2. 询问是否使用RAG
use_rag_response = interaction_handler(
interaction_type="choice",
prompt="是否使用RAG知识库辅助生成内容",
options=["", ""],
default="",
key=f"use_rag_{chapter_id}",
)
use_rag = use_rag_response == ""
rag_sources: list[str] = []
rag_store: Optional[str] = None
# 3. 如果使用RAG选择知识库
if use_rag:
rag_sources = self._select_rag_sources(interaction_handler, chapter_id, state)
if rag_sources:
rag_store = rag_sources[0]
else:
logger.warning("RAG已启用但没有可用的知识库来源将自动禁用")
use_rag = False
# 保存章节配置
config = {
"emphasis": emphasis.strip() if emphasis else "",
"rag_enabled": use_rag,
"rag_store": rag_store,
"rag_sources": rag_sources if use_rag else [],
}
state.setdefault("chapter_configs", {})[chapter_id] = config
logger.info(
f"章节配置已保存: 强调内容={'' if config['emphasis'] else ''}, "
f"RAG={'启用' if use_rag else '禁用'}"
)
return state
def _select_rag_sources(
self,
interaction_handler,
chapter_id: str,
state: Dict[str, Any],
) -> list[str]:
available_sources = self._get_available_sources(state)
if not available_sources:
return []
ordered_ids = self._order_sources(available_sources)
selected: List[str] = []
if "tender_doc" in available_sources and available_sources["tender_doc"].get("available"):
# 默认启用招标文件,不再询问主来源
selected.append("tender_doc")
remaining = [sid for sid in ordered_ids if sid != "tender_doc"]
else:
# 无招标文件时,仍需选择主来源
default_source = ordered_ids[0]
options = [
(source_id, available_sources[source_id]["label"])
for source_id in ordered_ids
]
primary_choice = interaction_handler(
interaction_type="choice",
prompt="选择主要RAG来源",
options=options,
default=default_source,
key=f"rag_primary_{chapter_id}",
)
if primary_choice not in available_sources:
primary_choice = default_source
selected.append(primary_choice)
remaining = [sid for sid in ordered_ids if sid not in selected]
loop_index = 0
while remaining:
add_more = interaction_handler(
interaction_type="confirm",
prompt="是否追加其它RAG来源",
default=False,
key=f"rag_additional_confirm_{chapter_id}_{loop_index}",
)
if not add_more:
break
add_options = [
(source_id, available_sources[source_id]["label"])
for source_id in remaining
]
next_choice = interaction_handler(
interaction_type="choice",
prompt="选择需要追加的RAG来源",
options=add_options,
default=remaining[0],
key=f"rag_additional_choice_{chapter_id}_{loop_index}",
)
if next_choice in remaining:
selected.append(next_choice)
remaining.remove(next_choice)
loop_index += 1
return selected
def _get_available_sources(self, state: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
sources = state.get("available_rag_sources") or {}
usable = {sid: meta for sid, meta in sources.items() if meta.get("available")}
if usable:
return usable
if state.get("rag_tool"):
return {
"global_kb": {
"id": "global_kb",
"label": "知识库",
"available": True,
"default": True,
}
}
return {}
def _order_sources(self, sources: Dict[str, Dict[str, Any]]) -> list[str]:
return sorted(
sources.keys(),
key=lambda sid: (
not sources[sid].get("default", False),
sources[sid].get("label", sid),
),
)

View File

@ -52,9 +52,4 @@ class PrepareChapterNode(BaseNode):
f"{current_chapter['id']} - {current_chapter['title']}"
)
# 判断是否需要用户交互仅1级标题
needs_interaction = current_chapter["level"] == 1
return self._update_state(
state, current_chapter=current_chapter, needs_interaction=needs_interaction
)
return self._update_state(state, current_chapter=current_chapter)

View File

@ -6,121 +6,19 @@ from typing import Any, Dict
from bidmaster.nodes.base import NodeContext
from bidmaster.nodes.content.generate_content import GenerateContentNode
from bidmaster.nodes.content.interact_user import InteractWithUserNode
from bidmaster.utils.document_context import DocumentContextMatch
class StubInteractionHandler:
def __init__(self, responses: Dict[str, Any]):
self.responses = responses
def __call__(self, interaction_type: str, **kwargs):
key = kwargs.get("key")
if key and key in self.responses:
return self.responses[key]
return kwargs.get("default")
def test_interact_user_allows_multiple_rag_sources():
chapter = {"id": "chapter_1", "title": "概述", "level": 1}
responses = {
"emphasis_chapter_1": "",
"use_rag_chapter_1": "",
"rag_additional_confirm_chapter_1_0": True,
"rag_additional_choice_chapter_1_0": "global_kb",
}
handler = StubInteractionHandler(responses)
state: Dict[str, Any] = {
"needs_interaction": True,
"current_chapter": chapter,
"chapter_configs": {},
"interaction_handler": handler,
"available_rag_sources": {
"tender_doc": {
"id": "tender_doc",
"label": "招标文件",
"available": True,
"default": True,
},
"global_kb": {
"id": "global_kb",
"label": "知识库",
"available": True,
"default": False,
},
},
}
node = InteractWithUserNode()
node.execute(state, NodeContext())
config = state["chapter_configs"][chapter["id"]]
assert config["rag_enabled"] is True
assert config["rag_sources"] == ["tender_doc", "global_kb"]
def test_interact_user_disables_rag_without_sources():
chapter = {"id": "chapter_1", "title": "概述", "level": 1}
responses = {
"emphasis_chapter_1": "",
"use_rag_chapter_1": "",
}
handler = StubInteractionHandler(responses)
state: Dict[str, Any] = {
"needs_interaction": True,
"current_chapter": chapter,
"chapter_configs": {},
"interaction_handler": handler,
"available_rag_sources": {},
}
node = InteractWithUserNode()
node.execute(state, NodeContext())
config = state["chapter_configs"][chapter["id"]]
assert config["rag_enabled"] is False
assert config["rag_sources"] == []
def test_interact_user_prompts_for_primary_when_no_tender_doc():
chapter = {"id": "chapter_1", "title": "概述", "level": 1}
responses = {
"emphasis_chapter_1": "",
"use_rag_chapter_1": "",
"rag_primary_chapter_1": "global_kb",
}
handler = StubInteractionHandler(responses)
state: Dict[str, Any] = {
"needs_interaction": True,
"current_chapter": chapter,
"chapter_configs": {},
"interaction_handler": handler,
"available_rag_sources": {
"global_kb": {
"id": "global_kb",
"label": "知识库",
"available": True,
"default": True,
}
},
}
node = InteractWithUserNode()
node.execute(state, NodeContext())
config = state["chapter_configs"][chapter["id"]]
assert config["rag_enabled"] is True
assert config["rag_sources"] == ["global_kb"]
class DummyTenderSearcher:
def __init__(self, label: str = "Tender snippet"):
self.label = label
self.last_query: str | None = None
def search(self, query: str, top_k: int): # pragma: no cover - simple stub
self.last_query = query
return [
DocumentContextMatch(
text="Tender snippet",
text=self.label,
section="章节一",
score=0.9,
source_type="paragraph",
@ -130,39 +28,41 @@ class DummyTenderSearcher:
class DummyRagTool:
def __init__(self):
def __init__(self, kb_content: str = "KB snippet"):
self.last_context: Dict[str, Any] | None = None
self.kb_content = kb_content
self.search_queries: list[str] = []
def search(self, query: str, k: int): # pragma: no cover - simple stub
return [{"content": "KB snippet", "metadata": {"source": "kb.docx"}}]
self.search_queries.append(query)
return [{"content": self.kb_content, "metadata": {"source": "kb.docx"}}]
def generate_content(self, task_id: str, context: Dict[str, Any]):
self.last_context = context
return "generated"
def test_generate_content_merges_multiple_sources():
chapter = {"id": "chapter_1", "title": "概述", "level": 1}
rag_tool = DummyRagTool()
state: Dict[str, Any] = {
def _base_state(chapter: Dict[str, Any], rag_tool: DummyRagTool) -> Dict[str, Any]:
return {
"current_chapter": chapter,
"chapter_queue": [chapter],
"chapter_configs": {
"chapter_1": {
"rag_enabled": True,
"rag_sources": ["tender_doc", "global_kb"],
}
},
"chapter_configs": {},
"chapter_children_map": {},
"generated_contents": {},
"rag_tool": rag_tool,
"tender_doc_searcher": DummyTenderSearcher(),
"available_rag_sources": {
"tender_doc": {"available": True},
"global_kb": {"available": True},
},
}
def test_generate_content_auto_uses_all_available_sources():
chapter = {"id": "chapter_1", "title": "概述", "level": 1}
rag_tool = DummyRagTool()
state: Dict[str, Any] = _base_state(chapter, rag_tool)
state["available_rag_sources"] = {
"tender_doc": {"available": True},
"global_kb": {"available": True},
}
state["tender_doc_searcher"] = DummyTenderSearcher()
node = GenerateContentNode()
node.execute(state, NodeContext())
@ -171,3 +71,36 @@ def test_generate_content_merges_multiple_sources():
rag_context = rag_tool.last_context.get("rag_context")
assert "【招标文件】" in rag_context
assert "【知识库】" in rag_context
def test_generate_content_falls_back_when_tender_doc_missing():
chapter = {"id": "chapter_1", "title": "概述", "level": 1}
rag_tool = DummyRagTool(kb_content="Only KB")
state: Dict[str, Any] = _base_state(chapter, rag_tool)
state["available_rag_sources"] = {
"tender_doc": {"available": False},
"global_kb": {"available": True},
}
node = GenerateContentNode()
node.execute(state, NodeContext())
rag_context = rag_tool.last_context.get("rag_context")
assert "Only KB" in rag_context
assert "【招标文件】" not in rag_context
def test_generate_content_fallbacks_to_global_when_sources_unavailable():
chapter = {"id": "chapter_1", "title": "概述", "level": 1}
rag_tool = DummyRagTool(kb_content="fallback")
state: Dict[str, Any] = _base_state(chapter, rag_tool)
state["available_rag_sources"] = {
"tender_doc": {"available": False},
"global_kb": {"available": False},
}
node = GenerateContentNode()
node.execute(state, NodeContext())
# 即使标记不可用,也会退回全局知识库,保证填充不中断
assert "fallback" in rag_tool.last_context.get("rag_context")