From 3d12f9d065ee4e30bcf6f1508ec9e04b2f16d929 Mon Sep 17 00:00:00 2001 From: sladro Date: Fri, 19 Dec 2025 13:35:04 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20add=20rag=20context=20and=20filters,?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=86=E5=AF=B9rag=E6=8F=90=E5=8F=96?= =?UTF-8?q?=E7=9A=84=E5=8A=A0=E5=BC=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.yaml | 14 +- config/prompts.yaml | 2 +- src/bidmaster/config/settings.py | 18 +++ .../nodes/content/generate_content.py | 117 ++++++++++++--- src/bidmaster/tools/word.py | 4 +- src/bidmaster/tools/word_formatter.py | 1 - src/bidmaster/utils/outline_context.py | 111 ++++++++++++++ src/bidmaster/utils/rag_context.py | 91 +++++++++++ tests/unit/test_rag_context_utils.py | 27 ++++ tests/unit/test_rag_selection_filters.py | 142 ++++++++++++++++++ 10 files changed, 503 insertions(+), 24 deletions(-) create mode 100644 src/bidmaster/utils/outline_context.py create mode 100644 src/bidmaster/utils/rag_context.py create mode 100644 tests/unit/test_rag_context_utils.py create mode 100644 tests/unit/test_rag_selection_filters.py diff --git a/config/config.yaml b/config/config.yaml index 7b51bc2..87d3e73 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -43,4 +43,16 @@ interaction: # 日志设置 logging: level: INFO - format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" \ No newline at end of file + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + +# RAG 检索与上下文控制(注意:这些必须是顶层字段,才能映射到 Settings) +rag_search_top_k: 3 +rag_search_top_k_max: 20 +rag_similarity_threshold: 0.0 + +tender_doc_search_top_k: 3 +tender_doc_search_top_k_max: 20 +tender_doc_similarity_threshold: 0.2 + +rag_context_dedup: true +rag_context_token_budget: 1200 \ No newline at end of file diff --git a/config/prompts.yaml b/config/prompts.yaml index 1301ad9..627b6c4 100644 --- a/config/prompts.yaml +++ b/config/prompts.yaml @@ -297,7 +297,7 @@ content_prompts: 要求: 1. 内容专业、详实,符合招标文件要求 - 2. 突出技术优势和实施能力 + 2. 突出技术优势和实施能力,在技术上,流程上一定要保证各个环节章节的一致性 3. 严格遵守写作模式与篇幅要求: 有子标题仅写父级统领概要(不展开子标题正文),无子标题输出完整正文 4. 严禁新增任何章/节级标题或“商务条款、技术偏差、响应情况”等模板段,如需结构化仅使用普通段落或加粗语句 5. 开头不得出现“经认真研读招标文件要求”“偏差说明如下”等跨章节套话,内容必须围绕《{title}》本身展开 diff --git a/src/bidmaster/config/settings.py b/src/bidmaster/config/settings.py index da6462f..1c5e39b 100644 --- a/src/bidmaster/config/settings.py +++ b/src/bidmaster/config/settings.py @@ -83,7 +83,25 @@ class Settings(BaseSettings): # Word填充配置 max_sub_chapter_level: int = Field(default=3, description="子章节最大层级") rag_search_top_k: int = Field(default=3, description="RAG检索返回结果数量") + rag_search_top_k_max: int = Field( + default=0, + description="RAG检索最大候选数量(0表示使用rag_search_top_k)", + ) + rag_similarity_threshold: float = Field( + default=0.0, + description="知识库RAG相似度阈值(<=0表示不启用)", + ) tender_doc_search_top_k: int = Field(default=3, description="招标文档检索返回结果数量") + tender_doc_search_top_k_max: int = Field( + default=0, + description="招标文档检索最大候选数量(0表示使用tender_doc_search_top_k)", + ) + tender_doc_similarity_threshold: float = Field( + default=0.0, + description="招标文档RAG相似度阈值(<=0表示不启用)", + ) + rag_context_dedup: bool = Field(default=False, description="是否对RAG片段按文本去重") + rag_context_token_budget: int = Field(default=0, description="RAG上下文token预算(0表示不启用)") parent_context_length: int = Field(default=500, description="父章节上下文长度限制") # Word格式配置 diff --git a/src/bidmaster/nodes/content/generate_content.py b/src/bidmaster/nodes/content/generate_content.py index e471c75..ec29054 100644 --- a/src/bidmaster/nodes/content/generate_content.py +++ b/src/bidmaster/nodes/content/generate_content.py @@ -4,16 +4,25 @@ """ import logging +from dataclasses import dataclass from typing import Any, Dict, List, Optional from ..base import BaseNode, NodeContext from ...config.settings import get_settings +from ...utils.rag_context import fit_texts_to_token_budget, normalize_text_for_dedup from ...utils.prompt_planner import PromptPlanner logger = logging.getLogger(__name__) settings = get_settings() +@dataclass(frozen=True) +class _RagCandidate: + display_text: str + dedup_text: str + score: float + + class GenerateContentNode(BaseNode): """内容生成节点 @@ -210,16 +219,17 @@ class GenerateContentNode(BaseNode): rag_sources = config.get("rag_sources") or self._default_rag_sources(state) config["rag_sources"] = rag_sources - aggregated_context: List[str] = [] + aggregated_candidates: List[_RagCandidate] = [] for source_id in rag_sources: if source_id == "tender_doc": - aggregated_context.extend(self._search_tender_doc(state, query)) + aggregated_candidates.extend(self._search_tender_doc(state, query)) elif source_id == "global_kb": - aggregated_context.extend(self._search_global_kb(state, rag_tool, query)) + aggregated_candidates.extend(self._search_global_kb(state, rag_tool, query)) else: logger.warning("未知的RAG来源: %s", source_id) - generation_context["rag_context"] = "\n\n".join(aggregated_context) if aggregated_context else "" + rag_texts = self._finalize_rag_context(aggregated_candidates) + generation_context["rag_context"] = "\n\n".join(rag_texts) if rag_texts else "" # 调用生成方法 try: @@ -242,54 +252,123 @@ class GenerateContentNode(BaseNode): return ["global_kb"] return [] - def _search_tender_doc(self, state: Dict[str, Any], query: str) -> List[str]: + def _finalize_rag_context(self, candidates: List[_RagCandidate]) -> List[str]: + if not candidates: + return [] + + filtered = candidates + if getattr(settings, "rag_context_dedup", False): + seen: set[str] = set() + deduped: List[_RagCandidate] = [] + for item in filtered: + key = normalize_text_for_dedup(item.dedup_text) + if not key or key in seen: + continue + seen.add(key) + deduped.append(item) + filtered = deduped + + texts = [item.display_text for item in filtered] + budget = int(getattr(settings, "rag_context_token_budget", 0) or 0) + texts = fit_texts_to_token_budget(texts, budget) + + if getattr(settings, "rag_context_dedup", False) or budget > 0: + logger.info( + "RAG上下文最终选取 %s 条片段(候选 %s 条,去重=%s,token预算=%s)", + len(texts), + len(candidates), + bool(getattr(settings, "rag_context_dedup", False)), + budget, + ) + + return texts + + def _search_tender_doc(self, state: Dict[str, Any], query: str) -> List[_RagCandidate]: searcher = state.get("tender_doc_searcher") if not searcher: return [] - top_k = getattr(settings, "tender_doc_search_top_k", settings.rag_search_top_k) + base_k = getattr(settings, "tender_doc_search_top_k", settings.rag_search_top_k) + max_k = getattr(settings, "tender_doc_search_top_k_max", 0) or base_k + max_k = max(int(max_k), 1) + threshold = float(getattr(settings, "tender_doc_similarity_threshold", 0.0) or 0.0) try: - matches = searcher.search(query, top_k) + matches = searcher.search(query, max_k) except Exception as exc: logger.warning("招标文件RAG检索失败: %s", exc) return [] - contexts: List[str] = [] + selected: List[_RagCandidate] = [] for match in matches: section = getattr(match, "section", "招标文件") or "招标文件" text = getattr(match, "text", "") + score = float(getattr(match, "score", 0.0) or 0.0) if not text: continue - contexts.append(f"【招标文件】{section}\n{text}") + if threshold > 0 and score < threshold: + continue + selected.append( + _RagCandidate( + display_text=f"【招标文件】{section}\n{text}", + dedup_text=text, + score=score, + ) + ) - if contexts: - logger.info("招标文件RAG检索到 %s 条内容", len(contexts)) + if selected: + logger.info( + "招标文件RAG选取 %s 条内容(候选 %s 条,max=%s,阈值=%s)", + len(selected), + len(matches or []), + max_k, + threshold, + ) else: logger.warning("招标文件RAG未检索到相关内容") - return contexts + return selected - def _search_global_kb(self, state: Dict[str, Any], rag_tool, query: str) -> List[str]: + def _search_global_kb(self, state: Dict[str, Any], rag_tool, query: str) -> List[_RagCandidate]: if not rag_tool: return [] + base_k = int(getattr(settings, "rag_search_top_k", 3) or 3) + max_k = getattr(settings, "rag_search_top_k_max", 0) or base_k + max_k = max(int(max_k), 1) + threshold = float(getattr(settings, "rag_similarity_threshold", 0.0) or 0.0) + try: - results = rag_tool.search(query, k=settings.rag_search_top_k) + results = rag_tool.search(query, k=max_k) except Exception as exc: logger.warning("知识库RAG检索失败: %s", exc) return [] - contexts: List[str] = [] + selected: List[_RagCandidate] = [] for result in results or []: content = result.get("content") if not content: continue source = (result.get("metadata") or {}).get("source") or "知识库片段" - contexts.append(f"【知识库】{source}\n{content}") + score = float(result.get("score") or 0.0) + if threshold > 0 and score < threshold: + continue + selected.append( + _RagCandidate( + display_text=f"【知识库】{source}\n{content}", + dedup_text=content, + score=score, + ) + ) - if contexts: - logger.info("知识库RAG检索到 %s 条内容", len(contexts)) + if selected: + logger.info( + "知识库RAG选取 %s 条内容(候选 %s 条,max=%s,阈值=%s)", + len(selected), + len(results or []), + max_k, + threshold, + ) else: logger.warning("知识库RAG未检索到相关内容") - return contexts \ No newline at end of file + return selected \ No newline at end of file diff --git a/src/bidmaster/tools/word.py b/src/bidmaster/tools/word.py index 95203d6..8139812 100644 --- a/src/bidmaster/tools/word.py +++ b/src/bidmaster/tools/word.py @@ -107,7 +107,7 @@ class WordProcessor: # 使用Word的标题样式 if chapter.level <= MAX_HEADING_LEVEL: - heading = doc.add_heading(title_text, level=chapter.level) + doc.add_heading(title_text, level=chapter.level) else: # 超过最大层级用普通段落加粗 para = doc.add_paragraph() @@ -119,7 +119,7 @@ class WordProcessor: # 为有内容的章节添加占位符 if chapter.template_placeholder: - content_para = doc.add_paragraph(f"\n{chapter.template_placeholder}\n") + doc.add_paragraph(f"\n{chapter.template_placeholder}\n") # 添加写作指导 if chapter.score and chapter.score > SCORE_THRESHOLD: diff --git a/src/bidmaster/tools/word_formatter.py b/src/bidmaster/tools/word_formatter.py index ef5dbdc..2838538 100644 --- a/src/bidmaster/tools/word_formatter.py +++ b/src/bidmaster/tools/word_formatter.py @@ -67,7 +67,6 @@ class WordFormatter: for pattern, style in patterns: match = re.match(f'^{pattern}$', part.strip()) if match: - number_part = match.group(1) suffix = match.group(2) if match.lastindex >= 2 else "" format_dict[level] = { diff --git a/src/bidmaster/utils/outline_context.py b/src/bidmaster/utils/outline_context.py new file mode 100644 index 0000000..d90c376 --- /dev/null +++ b/src/bidmaster/utils/outline_context.py @@ -0,0 +1,111 @@ +"""生成章节目录上下文信息。""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + + +class OutlineContextBuilder: + """根据章节列表构建目录摘要与章节上下文。""" + + def __init__(self, chapters: Optional[List[Dict[str, Any]]]): + self.chapters = chapters or [] + self.chapter_map = {ch["id"]: ch for ch in self.chapters} + self.children_map = self._build_children_map() + + def build(self) -> Dict[str, Any]: + outline_summary = self._build_outline_summary() + contexts: Dict[str, Dict[str, Any]] = {} + for index, chapter in enumerate(self.chapters): + contexts[chapter["id"]] = self._build_chapter_context(chapter, index) + return { + "outline_summary": outline_summary, + "chapter_contexts": contexts, + } + + def _build_children_map(self) -> Dict[str, List[Dict[str, Any]]]: + result: Dict[str, List[Dict[str, Any]]] = {} + for chapter in self.chapters: + parent_id = chapter.get("parent_id") + if not parent_id: + continue + result.setdefault(parent_id, []).append(chapter) + for siblings in result.values(): + siblings.sort(key=lambda item: item.get("order_index", 0)) + return result + + def _build_outline_summary(self, limit: int = 1600) -> str: + lines: List[str] = [] + for chapter in self.chapters: + level = max(chapter.get("level", 1) - 1, 0) + indent = " " * level + heading = chapter.get("heading_number") + label = chapter.get("title") or chapter.get("raw_heading") or chapter["id"] + if heading: + label = f"{heading} {label}" + lines.append(f"{indent}- {label}") + summary = "\n".join(lines) + return self._truncate(summary, limit) + + def _build_chapter_context(self, chapter: Dict[str, Any], index: int) -> Dict[str, Any]: + parent_chain = self._collect_parent_titles(chapter) + siblings = [c for c in self.children_map.get(chapter.get("parent_id"), []) if c["id"] != chapter["id"]] + prev_chapter = self.chapters[index - 1] if index > 0 else None + next_chapter = self.chapters[index + 1] if index + 1 < len(self.chapters) else None + + context = { + "chapter_path": self._compose_path(parent_chain, chapter), + "parent_chain": parent_chain, + "siblings": [self._snapshot(sib) for sib in siblings], + "previous": self._snapshot(prev_chapter), + "next": self._snapshot(next_chapter), + } + context["relation_hint"] = self._compose_relation_hint(context) + return context + + def _collect_parent_titles(self, chapter: Dict[str, Any]) -> List[str]: + chain: List[str] = [] + parent_id = chapter.get("parent_id") + while parent_id: + parent = self.chapter_map.get(parent_id) + if not parent: + break + chain.append(parent.get("title") or parent.get("raw_heading") or parent_id) + parent_id = parent.get("parent_id") + return list(reversed(chain)) + + def _compose_path(self, parent_chain: List[str], chapter: Dict[str, Any]) -> str: + segments = parent_chain + [chapter.get("title") or chapter.get("raw_heading") or chapter["id"]] + return " > ".join(segments) + + def _snapshot(self, chapter: Optional[Dict[str, Any]]) -> Optional[Dict[str, str]]: + if not chapter: + return None + return { + "id": chapter.get("id", ""), + "title": chapter.get("title") or chapter.get("raw_heading") or chapter.get("id", ""), + } + + def _compose_relation_hint(self, context: Dict[str, Any]) -> str: + hints: List[str] = [] + parent_chain = context.get("parent_chain") or [] + siblings = context.get("siblings") or [] + prev_chapter = context.get("previous") + next_chapter = context.get("next") + + if parent_chain: + hints.append(f"父级链路:{' > '.join(parent_chain)}") + if prev_chapter: + hints.append(f"前一章节:{prev_chapter['title']}") + if next_chapter: + hints.append(f"后一章节:{next_chapter['title']}") + if siblings: + sibling_titles = "、".join(sib["title"] for sib in siblings[:4]) + hints.append(f"同级章节:{sibling_titles}") + + return " | ".join(hints) if hints else "该章节为独立一级主题,可直接展开。" + + def _truncate(self, text: str, limit: int) -> str: + if len(text) <= limit: + return text + return text[: limit - 3] + "..." diff --git a/src/bidmaster/utils/rag_context.py b/src/bidmaster/utils/rag_context.py new file mode 100644 index 0000000..9ed454e --- /dev/null +++ b/src/bidmaster/utils/rag_context.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +import math +import re +from typing import List + + +_WHITESPACE_RE = re.compile(r"\s+") + + +def normalize_text_for_dedup(text: str) -> str: + if not text: + return "" + return _WHITESPACE_RE.sub(" ", text).strip().lower() + + +def estimate_tokens(text: str) -> int: + """Estimate token usage without extra dependencies. + + Heuristic: + - CJK characters: ~1 token per char + - Non-CJK: ~1 token per 4 characters + """ + if not text: + return 0 + + cjk = 0 + for ch in text: + if "\u4e00" <= ch <= "\u9fff": + cjk += 1 + + non_cjk_len = len(text) - cjk + return cjk + int(math.ceil(non_cjk_len / 4)) + + +def truncate_to_token_budget(text: str, token_budget: int) -> str: + if not text or token_budget <= 0: + return "" + if estimate_tokens(text) <= token_budget: + return text + + lo, hi = 0, len(text) + while lo < hi: + mid = (lo + hi + 1) // 2 + if estimate_tokens(text[:mid]) <= token_budget: + lo = mid + else: + hi = mid - 1 + + return text[:lo].rstrip() + + +def fit_texts_to_token_budget( + texts: List[str], + token_budget: int, + *, + separator: str = "\n\n", +) -> List[str]: + if token_budget <= 0: + return [text for text in texts if (text or "").strip()] + + selected: List[str] = [] + used = 0 + sep_tokens = estimate_tokens(separator) + + for text in texts: + if not (text or "").strip(): + continue + + add_sep = sep_tokens if selected else 0 + text_tokens = estimate_tokens(text) + + if used + add_sep + text_tokens <= token_budget: + if selected: + used += sep_tokens + selected.append(text) + used += text_tokens + continue + + remaining = token_budget - used - add_sep + if remaining <= 0: + break + + truncated = truncate_to_token_budget(text, remaining) + if truncated: + if selected: + used += sep_tokens + selected.append(truncated) + break + + return selected diff --git a/tests/unit/test_rag_context_utils.py b/tests/unit/test_rag_context_utils.py new file mode 100644 index 0000000..6ef7799 --- /dev/null +++ b/tests/unit/test_rag_context_utils.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from bidmaster.utils.rag_context import ( + estimate_tokens, + fit_texts_to_token_budget, + truncate_to_token_budget, +) + + +def test_estimate_tokens_cjk_vs_ascii(): + assert estimate_tokens("中文") >= 2 + assert estimate_tokens("abcd") >= 1 + + +def test_truncate_to_token_budget_truncates(): + text = "中文中文中文" # 6 CJK chars + truncated = truncate_to_token_budget(text, 3) + assert truncated + assert estimate_tokens(truncated) <= 3 + + +def test_fit_texts_to_token_budget_drops_or_truncates_tail(): + first = "中文中文" # ~4 tokens + second = "中文中文中文" # ~6 tokens + budget = estimate_tokens(first) + fitted = fit_texts_to_token_budget([first, second], budget) + assert fitted == [first] diff --git a/tests/unit/test_rag_selection_filters.py b/tests/unit/test_rag_selection_filters.py new file mode 100644 index 0000000..e603e1c --- /dev/null +++ b/tests/unit/test_rag_selection_filters.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from typing import Any, Dict + +from bidmaster.nodes.base import NodeContext +from bidmaster.nodes.content.generate_content import GenerateContentNode, settings as node_settings +from bidmaster.utils.document_context import DocumentContextMatch +from bidmaster.utils.rag_context import estimate_tokens + + +class DummyTenderSearcher: + def __init__(self, matches: list[DocumentContextMatch]): + self.matches = matches + self.last_query: str | None = None + + def search(self, query: str, top_k: int): # pragma: no cover + self.last_query = query + return self.matches[:top_k] + + +class DummyRagTool: + def __init__(self, results: list[dict[str, Any]]): + self.results = results + self.last_context: Dict[str, Any] | None = None + + def search(self, query: str, k: int): # pragma: no cover + return self.results[:k] + + def generate_content(self, task_id: str, context: Dict[str, Any]): # pragma: no cover + self.last_context = context + return "generated" + + +def _base_state(chapter: Dict[str, Any], rag_tool: DummyRagTool) -> Dict[str, Any]: + return { + "current_chapter": chapter, + "chapter_queue": [chapter], + "chapter_configs": {}, + "chapter_children_map": {}, + "generated_contents": {}, + "rag_tool": rag_tool, + } + + +def test_kb_rag_similarity_threshold_filters_low_scores(monkeypatch): + monkeypatch.setattr(node_settings, "rag_search_top_k_max", 10) + monkeypatch.setattr(node_settings, "rag_similarity_threshold", 0.8) + monkeypatch.setattr(node_settings, "rag_context_dedup", False) + monkeypatch.setattr(node_settings, "rag_context_token_budget", 0) + + chapter = {"id": "chapter_1", "title": "概述", "level": 1} + rag_tool = DummyRagTool( + results=[ + {"content": "LOW", "metadata": {"source": "kb"}, "score": 0.2}, + {"content": "HIGH", "metadata": {"source": "kb"}, "score": 0.95}, + ] + ) + state: Dict[str, Any] = _base_state(chapter, rag_tool) + state["available_rag_sources"] = {"global_kb": {"available": True}} + + node = GenerateContentNode() + node.execute(state, NodeContext()) + + rag_context = rag_tool.last_context.get("rag_context") + assert "HIGH" in rag_context + assert "LOW" not in rag_context + + +def test_rag_context_dedup_removes_duplicate_texts_across_sources(monkeypatch): + monkeypatch.setattr(node_settings, "rag_context_dedup", True) + monkeypatch.setattr(node_settings, "rag_context_token_budget", 0) + monkeypatch.setattr(node_settings, "rag_similarity_threshold", 0.0) + monkeypatch.setattr(node_settings, "tender_doc_similarity_threshold", 0.0) + + chapter = {"id": "chapter_1", "title": "概述", "level": 1} + rag_tool = DummyRagTool( + results=[ + {"content": "DUP", "metadata": {"source": "kb"}, "score": 0.9}, + ] + ) + state: Dict[str, Any] = _base_state(chapter, rag_tool) + state["available_rag_sources"] = { + "tender_doc": {"available": True}, + "global_kb": {"available": True}, + } + state["tender_doc_searcher"] = DummyTenderSearcher( + matches=[ + DocumentContextMatch( + text="DUP", + section="章节一", + score=0.9, + source_type="paragraph", + metadata={}, + ) + ] + ) + + node = GenerateContentNode() + node.execute(state, NodeContext()) + + rag_context = rag_tool.last_context.get("rag_context") + assert rag_context.count("DUP") == 1 + + +def test_rag_context_token_budget_limits_total_context(monkeypatch): + monkeypatch.setattr(node_settings, "rag_context_dedup", False) + monkeypatch.setattr(node_settings, "rag_similarity_threshold", 0.0) + monkeypatch.setattr(node_settings, "tender_doc_similarity_threshold", 0.0) + + first = "【招标文件】章节一\nA" + budget = estimate_tokens(first) + monkeypatch.setattr(node_settings, "rag_context_token_budget", budget) + + chapter = {"id": "chapter_1", "title": "概述", "level": 1} + rag_tool = DummyRagTool( + results=[ + {"content": "B", "metadata": {"source": "kb"}, "score": 0.9}, + ] + ) + state: Dict[str, Any] = _base_state(chapter, rag_tool) + state["available_rag_sources"] = { + "tender_doc": {"available": True}, + "global_kb": {"available": True}, + } + state["tender_doc_searcher"] = DummyTenderSearcher( + matches=[ + DocumentContextMatch( + text="A", + section="章节一", + score=0.9, + source_type="paragraph", + metadata={}, + ) + ] + ) + + node = GenerateContentNode() + node.execute(state, NodeContext()) + + rag_context = rag_tool.last_context.get("rag_context") + assert "A" in rag_context + assert "【知识库】" not in rag_context