feat: add rag context and filters,增加了对rag提取的加强
This commit is contained in:
parent
cdc9b1d757
commit
3d12f9d065
@ -43,4 +43,16 @@ interaction:
|
||||
# 日志设置
|
||||
logging:
|
||||
level: INFO
|
||||
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
|
||||
# RAG 检索与上下文控制(注意:这些必须是顶层字段,才能映射到 Settings)
|
||||
rag_search_top_k: 3
|
||||
rag_search_top_k_max: 20
|
||||
rag_similarity_threshold: 0.0
|
||||
|
||||
tender_doc_search_top_k: 3
|
||||
tender_doc_search_top_k_max: 20
|
||||
tender_doc_similarity_threshold: 0.2
|
||||
|
||||
rag_context_dedup: true
|
||||
rag_context_token_budget: 1200
|
||||
@ -297,7 +297,7 @@ content_prompts:
|
||||
|
||||
要求:
|
||||
1. 内容专业、详实,符合招标文件要求
|
||||
2. 突出技术优势和实施能力
|
||||
2. 突出技术优势和实施能力,在技术上,流程上一定要保证各个环节章节的一致性
|
||||
3. 严格遵守写作模式与篇幅要求: 有子标题仅写父级统领概要(不展开子标题正文),无子标题输出完整正文
|
||||
4. 严禁新增任何章/节级标题或“商务条款、技术偏差、响应情况”等模板段,如需结构化仅使用普通段落或加粗语句
|
||||
5. 开头不得出现“经认真研读招标文件要求”“偏差说明如下”等跨章节套话,内容必须围绕《{title}》本身展开
|
||||
|
||||
@ -83,7 +83,25 @@ class Settings(BaseSettings):
|
||||
# Word填充配置
|
||||
max_sub_chapter_level: int = Field(default=3, description="子章节最大层级")
|
||||
rag_search_top_k: int = Field(default=3, description="RAG检索返回结果数量")
|
||||
rag_search_top_k_max: int = Field(
|
||||
default=0,
|
||||
description="RAG检索最大候选数量(0表示使用rag_search_top_k)",
|
||||
)
|
||||
rag_similarity_threshold: float = Field(
|
||||
default=0.0,
|
||||
description="知识库RAG相似度阈值(<=0表示不启用)",
|
||||
)
|
||||
tender_doc_search_top_k: int = Field(default=3, description="招标文档检索返回结果数量")
|
||||
tender_doc_search_top_k_max: int = Field(
|
||||
default=0,
|
||||
description="招标文档检索最大候选数量(0表示使用tender_doc_search_top_k)",
|
||||
)
|
||||
tender_doc_similarity_threshold: float = Field(
|
||||
default=0.0,
|
||||
description="招标文档RAG相似度阈值(<=0表示不启用)",
|
||||
)
|
||||
rag_context_dedup: bool = Field(default=False, description="是否对RAG片段按文本去重")
|
||||
rag_context_token_budget: int = Field(default=0, description="RAG上下文token预算(0表示不启用)")
|
||||
parent_context_length: int = Field(default=500, description="父章节上下文长度限制")
|
||||
|
||||
# Word格式配置
|
||||
|
||||
@ -4,16 +4,25 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..base import BaseNode, NodeContext
|
||||
from ...config.settings import get_settings
|
||||
from ...utils.rag_context import fit_texts_to_token_budget, normalize_text_for_dedup
|
||||
from ...utils.prompt_planner import PromptPlanner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _RagCandidate:
    """Immutable record for one retrieved snippet before final selection."""

    # Text that is placed into the prompt (includes the source label header).
    display_text: str
    # Raw snippet text used as the basis for duplicate detection.
    dedup_text: str
    # Similarity score reported by the retriever for this snippet.
    score: float
|
||||
|
||||
|
||||
class GenerateContentNode(BaseNode):
|
||||
"""内容生成节点
|
||||
|
||||
@ -210,16 +219,17 @@ class GenerateContentNode(BaseNode):
|
||||
rag_sources = config.get("rag_sources") or self._default_rag_sources(state)
|
||||
config["rag_sources"] = rag_sources
|
||||
|
||||
aggregated_context: List[str] = []
|
||||
aggregated_candidates: List[_RagCandidate] = []
|
||||
for source_id in rag_sources:
|
||||
if source_id == "tender_doc":
|
||||
aggregated_context.extend(self._search_tender_doc(state, query))
|
||||
aggregated_candidates.extend(self._search_tender_doc(state, query))
|
||||
elif source_id == "global_kb":
|
||||
aggregated_context.extend(self._search_global_kb(state, rag_tool, query))
|
||||
aggregated_candidates.extend(self._search_global_kb(state, rag_tool, query))
|
||||
else:
|
||||
logger.warning("未知的RAG来源: %s", source_id)
|
||||
|
||||
generation_context["rag_context"] = "\n\n".join(aggregated_context) if aggregated_context else ""
|
||||
rag_texts = self._finalize_rag_context(aggregated_candidates)
|
||||
generation_context["rag_context"] = "\n\n".join(rag_texts) if rag_texts else ""
|
||||
|
||||
# 调用生成方法
|
||||
try:
|
||||
@ -242,54 +252,123 @@ class GenerateContentNode(BaseNode):
|
||||
return ["global_kb"]
|
||||
return []
|
||||
|
||||
def _search_tender_doc(self, state: Dict[str, Any], query: str) -> List[str]:
|
||||
def _finalize_rag_context(self, candidates: List[_RagCandidate]) -> List[str]:
    """Turn retrieved candidates into the final list of prompt snippets.

    Steps, in order:
      1. When ``settings.rag_context_dedup`` is truthy, keep only the first
         candidate for each normalized ``dedup_text``; candidates whose
         normalized text is empty are dropped.
      2. Pass the surviving ``display_text`` values through
         ``fit_texts_to_token_budget`` using ``settings.rag_context_token_budget``.

    Args:
        candidates: Candidates in the order they were collected.

    Returns:
        The display texts that survived de-duplication and budgeting.
    """
    if not candidates:
        return []

    dedup_enabled = getattr(settings, "rag_context_dedup", False)

    kept = candidates
    if dedup_enabled:
        seen_keys: set[str] = set()
        unique: List[_RagCandidate] = []
        for candidate in candidates:
            fingerprint = normalize_text_for_dedup(candidate.dedup_text)
            if fingerprint and fingerprint not in seen_keys:
                seen_keys.add(fingerprint)
                unique.append(candidate)
        kept = unique

    token_budget = int(getattr(settings, "rag_context_token_budget", 0) or 0)
    fitted = fit_texts_to_token_budget(
        [candidate.display_text for candidate in kept],
        token_budget,
    )

    # Only report when at least one of the two features is switched on.
    if dedup_enabled or token_budget > 0:
        logger.info(
            "RAG上下文最终选取 %s 条片段(候选 %s 条,去重=%s,token预算=%s)",
            len(fitted),
            len(candidates),
            bool(dedup_enabled),
            token_budget,
        )

    return fitted
|
||||
|
||||
def _search_tender_doc(self, state: Dict[str, Any], query: str) -> List[_RagCandidate]:
    """Search the tender document and return the hits as RAG candidates.

    The searcher is read from ``state["tender_doc_searcher"]``; when it is
    missing an empty list is returned. Up to ``tender_doc_search_top_k_max``
    hits are requested (falling back to ``tender_doc_search_top_k`` when the
    max is 0/unset, and never fewer than 1). Hits without text are skipped,
    and when ``tender_doc_similarity_threshold`` is greater than 0, hits
    scoring below it are skipped as well.

    NOTE(review): the base top-k only acts as a fallback for the max value;
    the selected list is not trimmed back to the base top-k here. Confirm
    this is intended.

    Args:
        state: Workflow state dict.
        query: Retrieval query text.

    Returns:
        Selected candidates in searcher order; empty when the searcher is
        absent, the search raises, or nothing passes the filters.
    """
    searcher = state.get("tender_doc_searcher")
    if not searcher:
        return []

    base_k = getattr(settings, "tender_doc_search_top_k", settings.rag_search_top_k)
    max_k = getattr(settings, "tender_doc_search_top_k_max", 0) or base_k
    max_k = max(int(max_k), 1)
    threshold = float(getattr(settings, "tender_doc_similarity_threshold", 0.0) or 0.0)
    try:
        matches = searcher.search(query, max_k)
    except Exception as exc:
        # Retrieval is best-effort: generation continues without this source.
        logger.warning("招标文件RAG检索失败: %s", exc)
        return []

    # A searcher may return None instead of an empty list; treat both alike,
    # matching how the knowledge-base search handles its results.
    matches = matches or []

    selected: List[_RagCandidate] = []
    for match in matches:
        section = getattr(match, "section", "招标文件") or "招标文件"
        text = getattr(match, "text", "")
        score = float(getattr(match, "score", 0.0) or 0.0)
        if not text:
            continue
        if threshold > 0 and score < threshold:
            continue
        selected.append(
            _RagCandidate(
                display_text=f"【招标文件】{section}\n{text}",
                dedup_text=text,
                score=score,
            )
        )

    if selected:
        logger.info(
            "招标文件RAG选取 %s 条内容(候选 %s 条,max=%s,阈值=%s)",
            len(selected),
            len(matches),
            max_k,
            threshold,
        )
    else:
        logger.warning("招标文件RAG未检索到相关内容")

    return selected
|
||||
|
||||
def _search_global_kb(self, state: Dict[str, Any], rag_tool, query: str) -> List[_RagCandidate]:
    """Query the global knowledge base and wrap the hits as RAG candidates.

    Requests up to ``rag_search_top_k_max`` results (falling back to
    ``rag_search_top_k`` when the max is 0/unset, never fewer than 1).
    Results without ``content`` are skipped; when
    ``rag_similarity_threshold`` is greater than 0, results scoring below it
    are skipped too.

    Args:
        state: Workflow state dict (not read by this method).
        rag_tool: Object exposing ``search(query, k=...)``; falsy disables
            the lookup.
        query: Retrieval query text.

    Returns:
        Selected candidates in result order; empty when the tool is absent,
        the search raises, or nothing passes the filters.
    """
    if not rag_tool:
        return []

    fallback_k = int(getattr(settings, "rag_search_top_k", 3) or 3)
    configured_max = getattr(settings, "rag_search_top_k_max", 0) or fallback_k
    max_k = max(int(configured_max), 1)
    threshold = float(getattr(settings, "rag_similarity_threshold", 0.0) or 0.0)

    try:
        results = rag_tool.search(query, k=max_k)
    except Exception as exc:
        logger.warning("知识库RAG检索失败: %s", exc)
        return []

    hits = results or []
    selected: List[_RagCandidate] = []
    for hit in hits:
        content = hit.get("content")
        if not content:
            continue
        source = (hit.get("metadata") or {}).get("source") or "知识库片段"
        score = float(hit.get("score") or 0.0)
        below_threshold = threshold > 0 and score < threshold
        if not below_threshold:
            selected.append(
                _RagCandidate(
                    display_text=f"【知识库】{source}\n{content}",
                    dedup_text=content,
                    score=score,
                )
            )

    if not selected:
        logger.warning("知识库RAG未检索到相关内容")
        return selected

    logger.info(
        "知识库RAG选取 %s 条内容(候选 %s 条,max=%s,阈值=%s)",
        len(selected),
        len(hits),
        max_k,
        threshold,
    )
    return selected
|
||||
@ -107,7 +107,7 @@ class WordProcessor:
|
||||
|
||||
# 使用Word的标题样式
|
||||
if chapter.level <= MAX_HEADING_LEVEL:
|
||||
heading = doc.add_heading(title_text, level=chapter.level)
|
||||
doc.add_heading(title_text, level=chapter.level)
|
||||
else:
|
||||
# 超过最大层级用普通段落加粗
|
||||
para = doc.add_paragraph()
|
||||
@ -119,7 +119,7 @@ class WordProcessor:
|
||||
|
||||
# 为有内容的章节添加占位符
|
||||
if chapter.template_placeholder:
|
||||
content_para = doc.add_paragraph(f"\n{chapter.template_placeholder}\n")
|
||||
doc.add_paragraph(f"\n{chapter.template_placeholder}\n")
|
||||
|
||||
# 添加写作指导
|
||||
if chapter.score and chapter.score > SCORE_THRESHOLD:
|
||||
|
||||
@ -67,7 +67,6 @@ class WordFormatter:
|
||||
for pattern, style in patterns:
|
||||
match = re.match(f'^{pattern}$', part.strip())
|
||||
if match:
|
||||
number_part = match.group(1)
|
||||
suffix = match.group(2) if match.lastindex >= 2 else ""
|
||||
|
||||
format_dict[level] = {
|
||||
|
||||
111
src/bidmaster/utils/outline_context.py
Normal file
111
src/bidmaster/utils/outline_context.py
Normal file
@ -0,0 +1,111 @@
|
||||
"""生成章节目录上下文信息。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class OutlineContextBuilder:
    """Build an outline summary and per-chapter context from a flat chapter list.

    Each chapter is a dict. ``id`` is required; ``parent_id``, ``title``,
    ``raw_heading``, ``heading_number``, ``level`` and ``order_index`` are
    read when present.
    """

    def __init__(self, chapters: Optional[List[Dict[str, Any]]]):
        """Index the chapters by id and group children under their parent id.

        Args:
            chapters: Chapter dicts in document order; ``None`` is treated as
                an empty list.
        """
        self.chapters = chapters or []
        self.chapter_map = {ch["id"]: ch for ch in self.chapters}
        self.children_map = self._build_children_map()

    def build(self) -> Dict[str, Any]:
        """Return ``{"outline_summary": str, "chapter_contexts": {id: context}}``."""
        outline_summary = self._build_outline_summary()
        contexts: Dict[str, Dict[str, Any]] = {}
        for index, chapter in enumerate(self.chapters):
            contexts[chapter["id"]] = self._build_chapter_context(chapter, index)
        return {
            "outline_summary": outline_summary,
            "chapter_contexts": contexts,
        }

    def _build_children_map(self) -> Dict[str, List[Dict[str, Any]]]:
        """Map each parent id to its children, sorted by ``order_index``.

        Chapters without a ``parent_id`` are not grouped, so top-level
        chapters never appear as anyone's siblings.
        """
        result: Dict[str, List[Dict[str, Any]]] = {}
        for chapter in self.chapters:
            parent_id = chapter.get("parent_id")
            if not parent_id:
                continue
            result.setdefault(parent_id, []).append(chapter)
        for siblings in result.values():
            siblings.sort(key=lambda item: item.get("order_index", 0))
        return result

    def _build_outline_summary(self, limit: int = 1600) -> str:
        """Render the outline as an indented bullet list, capped at ``limit`` chars."""
        lines: List[str] = []
        for chapter in self.chapters:
            level = max(chapter.get("level", 1) - 1, 0)
            indent = " " * level
            heading = chapter.get("heading_number")
            label = chapter.get("title") or chapter.get("raw_heading") or chapter["id"]
            if heading:
                label = f"{heading} {label}"
            lines.append(f"{indent}- {label}")
        summary = "\n".join(lines)
        return self._truncate(summary, limit)

    def _build_chapter_context(self, chapter: Dict[str, Any], index: int) -> Dict[str, Any]:
        """Collect path, ancestors, siblings and neighbours for one chapter.

        ``previous`` / ``next`` are the adjacent entries of the flat chapter
        list at ``index``, regardless of nesting level.
        """
        parent_chain = self._collect_parent_titles(chapter)
        siblings = [c for c in self.children_map.get(chapter.get("parent_id"), []) if c["id"] != chapter["id"]]
        prev_chapter = self.chapters[index - 1] if index > 0 else None
        next_chapter = self.chapters[index + 1] if index + 1 < len(self.chapters) else None

        context = {
            "chapter_path": self._compose_path(parent_chain, chapter),
            "parent_chain": parent_chain,
            "siblings": [self._snapshot(sib) for sib in siblings],
            "previous": self._snapshot(prev_chapter),
            "next": self._snapshot(next_chapter),
        }
        context["relation_hint"] = self._compose_relation_hint(context)
        return context

    def _collect_parent_titles(self, chapter: Dict[str, Any]) -> List[str]:
        """Return ancestor titles ordered from the outermost ancestor inward.

        The walk stops at an unknown parent id, and also as soon as an id
        repeats, so malformed data with a cyclic or self-referencing
        ``parent_id`` cannot loop forever.
        """
        chain: List[str] = []
        visited = {chapter.get("id")}
        parent_id = chapter.get("parent_id")
        while parent_id and parent_id not in visited:
            visited.add(parent_id)
            parent = self.chapter_map.get(parent_id)
            if not parent:
                break
            chain.append(parent.get("title") or parent.get("raw_heading") or parent_id)
            parent_id = parent.get("parent_id")
        return list(reversed(chain))

    def _compose_path(self, parent_chain: List[str], chapter: Dict[str, Any]) -> str:
        """Join the ancestor titles and the chapter's own label with `` > ``."""
        segments = parent_chain + [chapter.get("title") or chapter.get("raw_heading") or chapter["id"]]
        return " > ".join(segments)

    def _snapshot(self, chapter: Optional[Dict[str, Any]]) -> Optional[Dict[str, str]]:
        """Reduce a chapter to ``{"id", "title"}``; ``None`` stays ``None``."""
        if not chapter:
            return None
        return {
            "id": chapter.get("id", ""),
            "title": chapter.get("title") or chapter.get("raw_heading") or chapter.get("id", ""),
        }

    def _compose_relation_hint(self, context: Dict[str, Any]) -> str:
        """Summarise ancestors, neighbours and up to four siblings in one line."""
        hints: List[str] = []
        parent_chain = context.get("parent_chain") or []
        siblings = context.get("siblings") or []
        prev_chapter = context.get("previous")
        next_chapter = context.get("next")

        if parent_chain:
            hints.append(f"父级链路:{' > '.join(parent_chain)}")
        if prev_chapter:
            hints.append(f"前一章节:{prev_chapter['title']}")
        if next_chapter:
            hints.append(f"后一章节:{next_chapter['title']}")
        if siblings:
            sibling_titles = "、".join(sib["title"] for sib in siblings[:4])
            hints.append(f"同级章节:{sibling_titles}")

        return " | ".join(hints) if hints else "该章节为独立一级主题,可直接展开。"

    def _truncate(self, text: str, limit: int) -> str:
        """Cut ``text`` to at most ``limit`` characters, ending in ``...`` when cut."""
        if len(text) <= limit:
            return text
        if limit < 3:
            # No room for the ellipsis: a negative slice end would return
            # more than ``limit`` characters, so cut hard instead.
            return text[: max(limit, 0)]
        return text[: limit - 3] + "..."
|
||||
91
src/bidmaster/utils/rag_context.py
Normal file
91
src/bidmaster/utils/rag_context.py
Normal file
@ -0,0 +1,91 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
|
||||
# Matches any run of whitespace so it can be collapsed to a single space.
_WHITESPACE_RE = re.compile(r"\s+")


def normalize_text_for_dedup(text: str) -> str:
    """Return a canonical key for duplicate detection.

    Whitespace runs are collapsed to one space, the ends are trimmed and the
    result is lower-cased. Falsy input yields an empty string.
    """
    if text:
        collapsed = _WHITESPACE_RE.sub(" ", text)
        return collapsed.strip().lower()
    return ""
|
||||
|
||||
|
||||
def estimate_tokens(text: str) -> int:
    """Approximate the token count of ``text`` using only the standard library.

    Rule of thumb applied here:
    - every character in U+4E00..U+9FFF counts as one token;
    - all remaining characters count as one token per four, rounded up.

    Falsy input returns 0.
    """
    if not text:
        return 0

    cjk_count = sum(1 for char in text if "\u4e00" <= char <= "\u9fff")
    other_count = len(text) - cjk_count
    return cjk_count + int(math.ceil(other_count / 4))
|
||||
|
||||
|
||||
def truncate_to_token_budget(text: str, token_budget: int) -> str:
    """Return the longest prefix of ``text`` whose estimate fits ``token_budget``.

    Empty text or a non-positive budget gives ``""``. Text that already fits
    is returned unchanged; otherwise the prefix is found by binary search and
    trailing whitespace is stripped from it.
    """
    if token_budget <= 0 or not text:
        return ""
    if estimate_tokens(text) <= token_budget:
        return text

    # Invariant: the prefix of length ``fits`` is within budget, the prefix of
    # length ``overflows`` is not. The empty prefix costs 0 and the full text
    # was just shown to exceed the budget.
    fits, overflows = 0, len(text)
    while overflows - fits > 1:
        probe = (fits + overflows) // 2
        if estimate_tokens(text[:probe]) <= token_budget:
            fits = probe
        else:
            overflows = probe

    return text[:fits].rstrip()
|
||||
|
||||
|
||||
def fit_texts_to_token_budget(
    texts: List[str],
    token_budget: int,
    *,
    separator: str = "\n\n",
) -> List[str]:
    """Select texts, in order, until the estimated token budget is used up.

    Blank entries are always dropped. With a non-positive budget every
    non-blank text is returned. Otherwise texts are taken while they fit,
    counting the separator between consecutive entries; the first text that
    does not fit is truncated into the remaining room (if any) and selection
    stops there.

    Args:
        texts: Candidate texts in priority order.
        token_budget: Maximum estimated tokens; ``<= 0`` disables the limit.
        separator: String that will later join the texts; its cost is
            charged between entries.

    Returns:
        The selected (possibly tail-truncated) texts.
    """
    non_blank = [text for text in texts if (text or "").strip()]
    if token_budget <= 0:
        return non_blank

    separator_cost = estimate_tokens(separator)
    kept: List[str] = []
    spent = 0

    for piece in non_blank:
        joiner_cost = separator_cost if kept else 0
        piece_cost = estimate_tokens(piece)
        room = token_budget - spent - joiner_cost

        if piece_cost <= room:
            kept.append(piece)
            spent += joiner_cost + piece_cost
            continue

        # First piece that does not fit: squeeze in what we can, then stop.
        if room > 0:
            clipped = truncate_to_token_budget(piece, room)
            if clipped:
                kept.append(clipped)
        break

    return kept
|
||||
27
tests/unit/test_rag_context_utils.py
Normal file
27
tests/unit/test_rag_context_utils.py
Normal file
@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from bidmaster.utils.rag_context import (
|
||||
estimate_tokens,
|
||||
fit_texts_to_token_budget,
|
||||
truncate_to_token_budget,
|
||||
)
|
||||
|
||||
|
||||
def test_estimate_tokens_cjk_vs_ascii():
    """CJK text costs at least one token per char; ASCII text at least one."""
    minimum_by_sample = (("中文", 2), ("abcd", 1))
    for sample, minimum in minimum_by_sample:
        assert estimate_tokens(sample) >= minimum
|
||||
|
||||
|
||||
def test_truncate_to_token_budget_truncates():
    """A six-character CJK string is cut to a non-empty prefix within budget."""
    budget = 3
    shortened = truncate_to_token_budget("中文中文中文", budget)  # 6 CJK chars in
    assert shortened
    assert estimate_tokens(shortened) <= budget
|
||||
|
||||
|
||||
def test_fit_texts_to_token_budget_drops_or_truncates_tail():
    """With a budget equal to the first text's cost, only the first text is kept."""
    head = "中文中文"  # ~4 tokens
    tail = "中文中文中文"  # ~6 tokens
    result = fit_texts_to_token_budget([head, tail], estimate_tokens(head))
    assert result == [head]
|
||||
142
tests/unit/test_rag_selection_filters.py
Normal file
142
tests/unit/test_rag_selection_filters.py
Normal file
@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from bidmaster.nodes.base import NodeContext
|
||||
from bidmaster.nodes.content.generate_content import GenerateContentNode, settings as node_settings
|
||||
from bidmaster.utils.document_context import DocumentContextMatch
|
||||
from bidmaster.utils.rag_context import estimate_tokens
|
||||
|
||||
|
||||
class DummyTenderSearcher:
    """Test double for the tender-document searcher with canned matches."""

    def __init__(self, matches: list[DocumentContextMatch]):
        # Canned hits handed back by ``search``.
        self.matches = matches
        # Most recent query passed to ``search``; ``None`` until first call.
        self.last_query: str | None = None

    def search(self, query: str, top_k: int):  # pragma: no cover
        """Remember ``query`` and return the first ``top_k`` canned matches."""
        self.last_query = query
        limited = self.matches[:top_k]
        return limited
|
||||
|
||||
|
||||
class DummyRagTool:
    """Test double for the RAG tool: canned search results, recorded context."""

    def __init__(self, results: list[dict[str, Any]]):
        # Canned hits handed back by ``search``.
        self.results = results
        # Context passed to the latest ``generate_content`` call.
        self.last_context: Dict[str, Any] | None = None

    def search(self, query: str, k: int):  # pragma: no cover
        """Return the first ``k`` canned results; ``query`` is ignored."""
        limited = self.results[:k]
        return limited

    def generate_content(self, task_id: str, context: Dict[str, Any]):  # pragma: no cover
        """Record ``context`` for later assertions and return a fixed string."""
        self.last_context = context
        return "generated"
|
||||
|
||||
|
||||
def _base_state(chapter: Dict[str, Any], rag_tool: DummyRagTool) -> Dict[str, Any]:
    """Build the minimal workflow state for a single queued chapter."""
    state: Dict[str, Any] = {}
    state["current_chapter"] = chapter
    state["chapter_queue"] = [chapter]
    # No per-chapter config, children, or previously generated content.
    state["chapter_configs"] = {}
    state["chapter_children_map"] = {}
    state["generated_contents"] = {}
    state["rag_tool"] = rag_tool
    return state
|
||||
|
||||
|
||||
def test_kb_rag_similarity_threshold_filters_low_scores(monkeypatch):
    """Knowledge-base hits scoring below the threshold stay out of the context."""
    overrides = {
        "rag_search_top_k_max": 10,
        "rag_similarity_threshold": 0.8,
        "rag_context_dedup": False,
        "rag_context_token_budget": 0,
    }
    for name, value in overrides.items():
        monkeypatch.setattr(node_settings, name, value)

    kb_hits = [
        {"content": "LOW", "metadata": {"source": "kb"}, "score": 0.2},
        {"content": "HIGH", "metadata": {"source": "kb"}, "score": 0.95},
    ]
    rag_tool = DummyRagTool(results=kb_hits)
    chapter = {"id": "chapter_1", "title": "概述", "level": 1}
    state: Dict[str, Any] = _base_state(chapter, rag_tool)
    state["available_rag_sources"] = {"global_kb": {"available": True}}

    GenerateContentNode().execute(state, NodeContext())

    rag_context = rag_tool.last_context.get("rag_context")
    assert "HIGH" in rag_context
    assert "LOW" not in rag_context
|
||||
|
||||
|
||||
def test_rag_context_dedup_removes_duplicate_texts_across_sources(monkeypatch):
    """The same snippet returned by both sources appears only once in the context."""
    overrides = {
        "rag_context_dedup": True,
        "rag_context_token_budget": 0,
        "rag_similarity_threshold": 0.0,
        "tender_doc_similarity_threshold": 0.0,
    }
    for name, value in overrides.items():
        monkeypatch.setattr(node_settings, name, value)

    kb_hits = [{"content": "DUP", "metadata": {"source": "kb"}, "score": 0.9}]
    rag_tool = DummyRagTool(results=kb_hits)
    tender_hit = DocumentContextMatch(
        text="DUP",
        section="章节一",
        score=0.9,
        source_type="paragraph",
        metadata={},
    )

    chapter = {"id": "chapter_1", "title": "概述", "level": 1}
    state: Dict[str, Any] = _base_state(chapter, rag_tool)
    state["available_rag_sources"] = {
        "tender_doc": {"available": True},
        "global_kb": {"available": True},
    }
    state["tender_doc_searcher"] = DummyTenderSearcher(matches=[tender_hit])

    GenerateContentNode().execute(state, NodeContext())

    rag_context = rag_tool.last_context.get("rag_context")
    assert rag_context.count("DUP") == 1
|
||||
|
||||
|
||||
def test_rag_context_token_budget_limits_total_context(monkeypatch):
    """A budget sized for the tender snippet alone leaves no room for the KB hit."""
    tender_snippet = "【招标文件】章节一\nA"
    overrides = {
        "rag_context_dedup": False,
        "rag_similarity_threshold": 0.0,
        "tender_doc_similarity_threshold": 0.0,
        "rag_context_token_budget": estimate_tokens(tender_snippet),
    }
    for name, value in overrides.items():
        monkeypatch.setattr(node_settings, name, value)

    kb_hits = [{"content": "B", "metadata": {"source": "kb"}, "score": 0.9}]
    rag_tool = DummyRagTool(results=kb_hits)
    tender_hit = DocumentContextMatch(
        text="A",
        section="章节一",
        score=0.9,
        source_type="paragraph",
        metadata={},
    )

    chapter = {"id": "chapter_1", "title": "概述", "level": 1}
    state: Dict[str, Any] = _base_state(chapter, rag_tool)
    state["available_rag_sources"] = {
        "tender_doc": {"available": True},
        "global_kb": {"available": True},
    }
    state["tender_doc_searcher"] = DummyTenderSearcher(matches=[tender_hit])

    GenerateContentNode().execute(state, NodeContext())

    rag_context = rag_tool.last_context.get("rag_context")
    assert "A" in rag_context
    assert "【知识库】" not in rag_context
|
||||
Loading…
Reference in New Issue
Block a user