feat: add rag context and filters,加强了对 RAG 提取内容的控制与过滤

This commit is contained in:
sladro 2025-12-19 13:35:04 +08:00
parent cdc9b1d757
commit 3d12f9d065
10 changed files with 503 additions and 24 deletions

View File

@ -43,4 +43,16 @@ interaction:
# 日志设置
logging:
level: INFO
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# RAG 检索与上下文控制(注意:这些必须是顶层字段,才能映射到 Settings)
rag_search_top_k: 3
rag_search_top_k_max: 20
rag_similarity_threshold: 0.0
tender_doc_search_top_k: 3
tender_doc_search_top_k_max: 20
tender_doc_similarity_threshold: 0.2
rag_context_dedup: true
rag_context_token_budget: 1200

View File

@ -297,7 +297,7 @@ content_prompts:
要求:
1. 内容专业、详实,符合招标文件要求
2. 突出技术优势和实施能力,在技术上、流程上一定要保证各个环节章节的一致性
3. 严格遵守写作模式与篇幅要求: 有子标题仅写父级统领概要(不展开子标题正文),无子标题输出完整正文
4. 严禁新增任何章/节级标题或“商务条款、技术偏差、响应情况”等模板段,如需结构化仅使用普通段落或加粗语句
5. 开头不得出现“经认真研读招标文件要求”“偏差说明如下”等跨章节套话,内容必须围绕《{title}》本身展开

View File

@ -83,7 +83,25 @@ class Settings(BaseSettings):
# Word填充配置
max_sub_chapter_level: int = Field(default=3, description="子章节最大层级")
rag_search_top_k: int = Field(default=3, description="RAG检索返回结果数量")
rag_search_top_k_max: int = Field(
default=0,
description="RAG检索最大候选数量(0表示使用rag_search_top_k)",
)
rag_similarity_threshold: float = Field(
default=0.0,
description="知识库RAG相似度阈值(<=0表示不启用)",
)
tender_doc_search_top_k: int = Field(default=3, description="招标文档检索返回结果数量")
tender_doc_search_top_k_max: int = Field(
default=0,
description="招标文档检索最大候选数量(0表示使用tender_doc_search_top_k)",
)
tender_doc_similarity_threshold: float = Field(
default=0.0,
description="招标文档RAG相似度阈值(<=0表示不启用)",
)
rag_context_dedup: bool = Field(default=False, description="是否对RAG片段按文本去重")
rag_context_token_budget: int = Field(default=0, description="RAG上下文token预算(0表示不启用)")
parent_context_length: int = Field(default=500, description="父章节上下文长度限制")
# Word格式配置

View File

@ -4,16 +4,25 @@
"""
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from ..base import BaseNode, NodeContext
from ...config.settings import get_settings
from ...utils.rag_context import fit_texts_to_token_budget, normalize_text_for_dedup
from ...utils.prompt_planner import PromptPlanner
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass(frozen=True)
class _RagCandidate:
display_text: str
dedup_text: str
score: float
class GenerateContentNode(BaseNode):
"""内容生成节点
@ -210,16 +219,17 @@ class GenerateContentNode(BaseNode):
rag_sources = config.get("rag_sources") or self._default_rag_sources(state)
config["rag_sources"] = rag_sources
aggregated_context: List[str] = []
aggregated_candidates: List[_RagCandidate] = []
for source_id in rag_sources:
if source_id == "tender_doc":
aggregated_context.extend(self._search_tender_doc(state, query))
aggregated_candidates.extend(self._search_tender_doc(state, query))
elif source_id == "global_kb":
aggregated_context.extend(self._search_global_kb(state, rag_tool, query))
aggregated_candidates.extend(self._search_global_kb(state, rag_tool, query))
else:
logger.warning("未知的RAG来源: %s", source_id)
generation_context["rag_context"] = "\n\n".join(aggregated_context) if aggregated_context else ""
rag_texts = self._finalize_rag_context(aggregated_candidates)
generation_context["rag_context"] = "\n\n".join(rag_texts) if rag_texts else ""
# 调用生成方法
try:
@ -242,54 +252,123 @@ class GenerateContentNode(BaseNode):
return ["global_kb"]
return []
def _finalize_rag_context(self, candidates: List[_RagCandidate]) -> List[str]:
    """Apply optional dedup and token-budget trimming to gathered RAG snippets.

    Args:
        candidates: Snippets collected from all configured RAG sources,
            in retrieval order.

    Returns:
        The display texts that should be joined into the prompt context.
    """
    if not candidates:
        return []
    filtered = candidates
    if getattr(settings, "rag_context_dedup", False):
        # Keep only the first occurrence of each normalized snippet text.
        seen: set[str] = set()
        deduped: List[_RagCandidate] = []
        for item in filtered:
            key = normalize_text_for_dedup(item.dedup_text)
            if not key or key in seen:
                continue
            seen.add(key)
            deduped.append(item)
        filtered = deduped
    texts = [item.display_text for item in filtered]
    budget = int(getattr(settings, "rag_context_token_budget", 0) or 0)
    texts = fit_texts_to_token_budget(texts, budget)
    if getattr(settings, "rag_context_dedup", False) or budget > 0:
        # Fixed log format string: the original was missing the separator
        # before "token预算" and the closing parenthesis.
        logger.info(
            "RAG上下文最终选取 %s 条片段(候选 %s 条,去重=%s,token预算=%s)",
            len(texts),
            len(candidates),
            bool(getattr(settings, "rag_context_dedup", False)),
            budget,
        )
    return texts
def _search_tender_doc(self, state: Dict[str, Any], query: str) -> List[_RagCandidate]:
    """Search the tender-document index and return threshold-filtered candidates.

    Removed the stale pre-change lines left over from the diff (old top_k
    lookup, old ``contexts`` accumulation, old log call) and repaired the
    garbled log format string.

    Args:
        state: Workflow state; must hold a ``tender_doc_searcher`` to search.
        query: Retrieval query text.

    Returns:
        Candidates whose score passes the configured similarity threshold.
    """
    searcher = state.get("tender_doc_searcher")
    if not searcher:
        return []
    base_k = getattr(settings, "tender_doc_search_top_k", settings.rag_search_top_k)
    # 0 (or unset) max means "fall back to the base top_k"; never below 1.
    max_k = getattr(settings, "tender_doc_search_top_k_max", 0) or base_k
    max_k = max(int(max_k), 1)
    threshold = float(getattr(settings, "tender_doc_similarity_threshold", 0.0) or 0.0)
    try:
        matches = searcher.search(query, max_k)
    except Exception as exc:
        logger.warning("招标文件RAG检索失败: %s", exc)
        return []
    selected: List[_RagCandidate] = []
    for match in matches:
        section = getattr(match, "section", "招标文件") or "招标文件"
        text = getattr(match, "text", "")
        score = float(getattr(match, "score", 0.0) or 0.0)
        if not text:
            continue
        # Threshold <= 0 disables score filtering.
        if threshold > 0 and score < threshold:
            continue
        selected.append(
            _RagCandidate(
                display_text=f"【招标文件】{section}\n{text}",
                dedup_text=text,
                score=score,
            )
        )
    if selected:
        logger.info(
            "招标文件RAG选取 %s 条内容(候选 %s,max=%s,阈值=%s)",
            len(selected),
            len(matches or []),
            max_k,
            threshold,
        )
    else:
        logger.warning("招标文件RAG未检索到相关内容")
    return selected
def _search_global_kb(self, state: Dict[str, Any], rag_tool, query: str) -> List[_RagCandidate]:
    """Search the global knowledge base and return threshold-filtered candidates.

    Removed the stale pre-change lines left over from the diff (old def
    signature, old ``k=rag_search_top_k`` call, old ``contexts`` accumulation,
    old log call) and repaired the garbled log format string.

    Args:
        state: Workflow state (kept for interface symmetry; unused here).
        rag_tool: Knowledge-base search tool exposing ``search(query, k=...)``.
        query: Retrieval query text.

    Returns:
        Candidates whose score passes the configured similarity threshold.
    """
    if not rag_tool:
        return []
    base_k = int(getattr(settings, "rag_search_top_k", 3) or 3)
    # 0 (or unset) max means "fall back to the base top_k"; never below 1.
    max_k = getattr(settings, "rag_search_top_k_max", 0) or base_k
    max_k = max(int(max_k), 1)
    threshold = float(getattr(settings, "rag_similarity_threshold", 0.0) or 0.0)
    try:
        results = rag_tool.search(query, k=max_k)
    except Exception as exc:
        logger.warning("知识库RAG检索失败: %s", exc)
        return []
    selected: List[_RagCandidate] = []
    for result in results or []:
        content = result.get("content")
        if not content:
            continue
        source = (result.get("metadata") or {}).get("source") or "知识库片段"
        score = float(result.get("score") or 0.0)
        # Threshold <= 0 disables score filtering.
        if threshold > 0 and score < threshold:
            continue
        selected.append(
            _RagCandidate(
                display_text=f"【知识库】{source}\n{content}",
                dedup_text=content,
                score=score,
            )
        )
    if selected:
        logger.info(
            "知识库RAG选取 %s 条内容(候选 %s,max=%s,阈值=%s)",
            len(selected),
            len(results or []),
            max_k,
            threshold,
        )
    else:
        logger.warning("知识库RAG未检索到相关内容")
    return selected

View File

@ -107,7 +107,7 @@ class WordProcessor:
# 使用Word的标题样式
if chapter.level <= MAX_HEADING_LEVEL:
heading = doc.add_heading(title_text, level=chapter.level)
doc.add_heading(title_text, level=chapter.level)
else:
# 超过最大层级用普通段落加粗
para = doc.add_paragraph()
@ -119,7 +119,7 @@ class WordProcessor:
# 为有内容的章节添加占位符
if chapter.template_placeholder:
content_para = doc.add_paragraph(f"\n{chapter.template_placeholder}\n")
doc.add_paragraph(f"\n{chapter.template_placeholder}\n")
# 添加写作指导
if chapter.score and chapter.score > SCORE_THRESHOLD:

View File

@ -67,7 +67,6 @@ class WordFormatter:
for pattern, style in patterns:
match = re.match(f'^{pattern}$', part.strip())
if match:
number_part = match.group(1)
suffix = match.group(2) if match.lastindex >= 2 else ""
format_dict[level] = {

View File

@ -0,0 +1,111 @@
"""生成章节目录上下文信息。"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
class OutlineContextBuilder:
    """Build an outline summary plus per-chapter navigation context.

    Expects chapter dicts carrying at least ``id``; ``title``, ``raw_heading``,
    ``heading_number``, ``level``, ``parent_id`` and ``order_index`` are read
    when present.
    """

    def __init__(self, chapters: Optional[List[Dict[str, Any]]]):
        self.chapters = chapters or []
        self.chapter_map = {ch["id"]: ch for ch in self.chapters}
        self.children_map = self._build_children_map()

    def build(self) -> Dict[str, Any]:
        """Return the outline summary and a context dict keyed by chapter id."""
        contexts: Dict[str, Dict[str, Any]] = {
            ch["id"]: self._build_chapter_context(ch, idx)
            for idx, ch in enumerate(self.chapters)
        }
        return {
            "outline_summary": self._build_outline_summary(),
            "chapter_contexts": contexts,
        }

    def _build_children_map(self) -> Dict[str, List[Dict[str, Any]]]:
        """Group chapters by parent id, each group sorted by order_index."""
        grouped: Dict[str, List[Dict[str, Any]]] = {}
        for ch in self.chapters:
            pid = ch.get("parent_id")
            if pid:
                grouped.setdefault(pid, []).append(ch)
        for kids in grouped.values():
            kids.sort(key=lambda c: c.get("order_index", 0))
        return grouped

    def _build_outline_summary(self, limit: int = 1600) -> str:
        """Render an indented bullet list of the outline, truncated to *limit*."""
        bullet_lines: List[str] = []
        for ch in self.chapters:
            depth = max(ch.get("level", 1) - 1, 0)
            label = ch.get("title") or ch.get("raw_heading") or ch["id"]
            numbering = ch.get("heading_number")
            if numbering:
                label = f"{numbering} {label}"
            bullet_lines.append(f"{'  ' * depth}- {label}")
        return self._truncate("\n".join(bullet_lines), limit)

    def _build_chapter_context(self, chapter: Dict[str, Any], index: int) -> Dict[str, Any]:
        """Assemble path, neighbors and relation hint for one chapter."""
        ancestors = self._collect_parent_titles(chapter)
        peers = [
            c
            for c in self.children_map.get(chapter.get("parent_id"), [])
            if c["id"] != chapter["id"]
        ]
        before = self.chapters[index - 1] if index > 0 else None
        after = self.chapters[index + 1] if index + 1 < len(self.chapters) else None
        ctx: Dict[str, Any] = {
            "chapter_path": self._compose_path(ancestors, chapter),
            "parent_chain": ancestors,
            "siblings": [self._snapshot(p) for p in peers],
            "previous": self._snapshot(before),
            "next": self._snapshot(after),
        }
        ctx["relation_hint"] = self._compose_relation_hint(ctx)
        return ctx

    def _collect_parent_titles(self, chapter: Dict[str, Any]) -> List[str]:
        """Walk parent links upward; return titles root-first."""
        titles: List[str] = []
        pid = chapter.get("parent_id")
        while pid:
            node = self.chapter_map.get(pid)
            if node is None:
                break
            titles.append(node.get("title") or node.get("raw_heading") or pid)
            pid = node.get("parent_id")
        titles.reverse()
        return titles

    def _compose_path(self, parent_chain: List[str], chapter: Dict[str, Any]) -> str:
        """Join ancestor titles and the chapter's own label with ' > '."""
        own = chapter.get("title") or chapter.get("raw_heading") or chapter["id"]
        return " > ".join([*parent_chain, own])

    def _snapshot(self, chapter: Optional[Dict[str, Any]]) -> Optional[Dict[str, str]]:
        """Reduce a chapter dict to an id/title pair; None passes through."""
        if not chapter:
            return None
        return {
            "id": chapter.get("id", ""),
            "title": chapter.get("title") or chapter.get("raw_heading") or chapter.get("id", ""),
        }

    def _compose_relation_hint(self, context: Dict[str, Any]) -> str:
        """Summarize parent/prev/next/sibling relations as a ' | '-joined hint."""
        parts: List[str] = []
        chain = context.get("parent_chain") or []
        if chain:
            parts.append(f"父级链路:{' > '.join(chain)}")
        before = context.get("previous")
        if before:
            parts.append(f"前一章节:{before['title']}")
        after = context.get("next")
        if after:
            parts.append(f"后一章节:{after['title']}")
        peers = context.get("siblings") or []
        if peers:
            # NOTE(review): titles are concatenated with no separator — confirm
            # a "、" delimiter wasn't intended here.
            joined = "".join(p["title"] for p in peers[:4])
            parts.append(f"同级章节:{joined}")
        if not parts:
            return "该章节为独立一级主题,可直接展开。"
        return " | ".join(parts)

    def _truncate(self, text: str, limit: int) -> str:
        """Clip *text* to *limit* characters, reserving 3 for an ellipsis."""
        if len(text) <= limit:
            return text
        return f"{text[: limit - 3]}..."

View File

@ -0,0 +1,91 @@
from __future__ import annotations
import math
import re
from typing import List
# Matches any run of whitespace (spaces, tabs, newlines).
_WHITESPACE_RE = re.compile(r"\s+")


def normalize_text_for_dedup(text: str) -> str:
    """Return *text* lowercased with whitespace runs collapsed, for dedup keys."""
    if not text:
        return ""
    collapsed = _WHITESPACE_RE.sub(" ", text)
    return collapsed.strip().lower()
def estimate_tokens(text: str) -> int:
    """Estimate token usage without extra dependencies.

    Heuristic:
    - CJK characters: ~1 token per char
    - Non-CJK: ~1 token per 4 characters
    """
    if not text:
        return 0
    cjk_count = sum(1 for ch in text if "\u4e00" <= ch <= "\u9fff")
    other_count = len(text) - cjk_count
    # Integer ceiling division replaces math.ceil(other_count / 4).
    return cjk_count + (other_count + 3) // 4
def truncate_to_token_budget(text: str, token_budget: int) -> str:
    """Return the longest prefix of *text* whose estimated tokens fit the budget.

    Returns "" when the budget is non-positive or *text* is empty; a truncated
    prefix is right-stripped before being returned.
    """
    if not text or token_budget <= 0:
        return ""
    if estimate_tokens(text) <= token_budget:
        return text
    # Binary-search the largest prefix length that still fits the budget.
    best, low, high = 0, 1, len(text)
    while low <= high:
        mid = (low + high) // 2
        if estimate_tokens(text[:mid]) <= token_budget:
            best = mid
            low = mid + 1
        else:
            high = mid - 1
    return text[:best].rstrip()
def fit_texts_to_token_budget(
    texts: List[str],
    token_budget: int,
    *,
    separator: str = "\n\n",
) -> List[str]:
    """Select (and possibly tail-truncate) texts so the joined result fits the budget.

    A non-positive budget disables the limit and only drops blank entries.
    Separator tokens are charged between consecutive kept texts. The first
    entry that does not fully fit is truncated into the remaining budget
    (when possible) and selection stops there.
    """
    if token_budget <= 0:
        return [entry for entry in texts if (entry or "").strip()]
    kept: List[str] = []
    consumed = 0
    sep_cost = estimate_tokens(separator)
    for entry in texts:
        if not (entry or "").strip():
            continue
        join_cost = sep_cost if kept else 0
        entry_cost = estimate_tokens(entry)
        if consumed + join_cost + entry_cost <= token_budget:
            if kept:
                consumed += sep_cost
            kept.append(entry)
            consumed += entry_cost
            continue
        # Entry overflows: squeeze a truncated tail into whatever is left.
        leftover = token_budget - consumed - join_cost
        if leftover <= 0:
            break
        tail = truncate_to_token_budget(entry, leftover)
        if tail:
            if kept:
                consumed += sep_cost
            kept.append(tail)
        break
    return kept

View File

@ -0,0 +1,27 @@
from __future__ import annotations
from bidmaster.utils.rag_context import (
estimate_tokens,
fit_texts_to_token_budget,
truncate_to_token_budget,
)
def test_estimate_tokens_cjk_vs_ascii():
    """CJK text costs at least one token per char; ASCII at least one total."""
    cjk_estimate = estimate_tokens("中文")
    ascii_estimate = estimate_tokens("abcd")
    assert cjk_estimate >= 2
    assert ascii_estimate >= 1
def test_truncate_to_token_budget_truncates():
    """Over-budget text is cut to a non-empty prefix within the budget."""
    sample = "中文中文中文"  # six CJK characters, ~6 tokens
    result = truncate_to_token_budget(sample, 3)
    assert result
    assert estimate_tokens(result) <= 3
def test_fit_texts_to_token_budget_drops_or_truncates_tail():
    """With budget for only the first text, the second one is dropped."""
    head = "中文中文"  # roughly 4 tokens
    tail = "中文中文中文"  # roughly 6 tokens
    limit = estimate_tokens(head)
    assert fit_texts_to_token_budget([head, tail], limit) == [head]

View File

@ -0,0 +1,142 @@
from __future__ import annotations
from typing import Any, Dict
from bidmaster.nodes.base import NodeContext
from bidmaster.nodes.content.generate_content import GenerateContentNode, settings as node_settings
from bidmaster.utils.document_context import DocumentContextMatch
from bidmaster.utils.rag_context import estimate_tokens
class DummyTenderSearcher:
    """Test double that replays canned tender-document matches."""

    def __init__(self, matches: list[DocumentContextMatch]):
        self.matches = matches
        # Last query passed to search(); lets tests inspect what was asked.
        self.last_query: str | None = None

    def search(self, query: str, top_k: int):  # pragma: no cover
        """Record the query and return at most *top_k* canned matches."""
        self.last_query = query
        return self.matches[:top_k]
class DummyRagTool:
    """Test double standing in for the knowledge-base RAG tool."""

    def __init__(self, results: list[dict[str, Any]]):
        self.results = results
        # Captured generation context; lets tests inspect the rag_context.
        self.last_context: Dict[str, Any] | None = None

    def search(self, query: str, k: int):  # pragma: no cover
        """Return at most *k* canned results; the query text is ignored."""
        return self.results[:k]

    def generate_content(self, task_id: str, context: Dict[str, Any]):  # pragma: no cover
        """Capture the generation context and return a fixed marker string."""
        self.last_context = context
        return "generated"
def _base_state(chapter: Dict[str, Any], rag_tool: DummyRagTool) -> Dict[str, Any]:
return {
"current_chapter": chapter,
"chapter_queue": [chapter],
"chapter_configs": {},
"chapter_children_map": {},
"generated_contents": {},
"rag_tool": rag_tool,
}
def test_kb_rag_similarity_threshold_filters_low_scores(monkeypatch):
    """Knowledge-base snippets scoring below the threshold must be dropped."""
    for name, value in (
        ("rag_search_top_k_max", 10),
        ("rag_similarity_threshold", 0.8),
        ("rag_context_dedup", False),
        ("rag_context_token_budget", 0),
    ):
        monkeypatch.setattr(node_settings, name, value)
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}
    rag_tool = DummyRagTool(
        results=[
            {"content": "LOW", "metadata": {"source": "kb"}, "score": 0.2},
            {"content": "HIGH", "metadata": {"source": "kb"}, "score": 0.95},
        ]
    )
    state: Dict[str, Any] = _base_state(chapter, rag_tool)
    state["available_rag_sources"] = {"global_kb": {"available": True}}
    GenerateContentNode().execute(state, NodeContext())
    rag_context = rag_tool.last_context.get("rag_context")
    assert "HIGH" in rag_context
    assert "LOW" not in rag_context
def test_rag_context_dedup_removes_duplicate_texts_across_sources(monkeypatch):
    """Identical snippet text from different sources appears only once."""
    for name, value in (
        ("rag_context_dedup", True),
        ("rag_context_token_budget", 0),
        ("rag_similarity_threshold", 0.0),
        ("tender_doc_similarity_threshold", 0.0),
    ):
        monkeypatch.setattr(node_settings, name, value)
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}
    rag_tool = DummyRagTool(
        results=[{"content": "DUP", "metadata": {"source": "kb"}, "score": 0.9}]
    )
    state: Dict[str, Any] = _base_state(chapter, rag_tool)
    state["available_rag_sources"] = {
        "tender_doc": {"available": True},
        "global_kb": {"available": True},
    }
    state["tender_doc_searcher"] = DummyTenderSearcher(
        matches=[
            DocumentContextMatch(
                text="DUP",
                section="章节一",
                score=0.9,
                source_type="paragraph",
                metadata={},
            )
        ]
    )
    GenerateContentNode().execute(state, NodeContext())
    rag_context = rag_tool.last_context.get("rag_context")
    assert rag_context.count("DUP") == 1
def test_rag_context_token_budget_limits_total_context(monkeypatch):
    """A budget sized for the first snippet keeps it and drops the rest."""
    for name, value in (
        ("rag_context_dedup", False),
        ("rag_similarity_threshold", 0.0),
        ("tender_doc_similarity_threshold", 0.0),
    ):
        monkeypatch.setattr(node_settings, name, value)
    first = "【招标文件】章节一\nA"
    budget = estimate_tokens(first)
    monkeypatch.setattr(node_settings, "rag_context_token_budget", budget)
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}
    rag_tool = DummyRagTool(
        results=[{"content": "B", "metadata": {"source": "kb"}, "score": 0.9}]
    )
    state: Dict[str, Any] = _base_state(chapter, rag_tool)
    state["available_rag_sources"] = {
        "tender_doc": {"available": True},
        "global_kb": {"available": True},
    }
    state["tender_doc_searcher"] = DummyTenderSearcher(
        matches=[
            DocumentContextMatch(
                text="A",
                section="章节一",
                score=0.9,
                source_type="paragraph",
                metadata={},
            )
        ]
    )
    GenerateContentNode().execute(state, NodeContext())
    rag_context = rag_tool.last_context.get("rag_context")
    assert "A" in rag_context
    assert "【知识库】" not in rag_context