bidmaster-cli/src/bidmaster/tools/rag.py
"""RAG检索增强生成工具
基于ChromaDB的文档检索系统支持文档索引、相似度搜索和内容检索。
"""
import hashlib
import logging
from pathlib import Path
from typing import Any
import chromadb
from chromadb.config import Settings as ChromaSettings
from chromadb.utils import embedding_functions
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
PyPDFLoader,
TextLoader,
UnstructuredWordDocumentLoader,
)
from langchain_core.documents import Document
from ..config import get_settings
from ..config.prompt_manager import get_prompt_manager
logger = logging.getLogger(__name__)
class RAGTool:
    """RAG tool class."""

    def __init__(self) -> None:
        self.settings = get_settings()
        self.chroma_path = Path(self.settings.chroma_path)
        self.chroma_path.mkdir(parents=True, exist_ok=True)
        # Initialize the ChromaDB client
        self.client = chromadb.PersistentClient(
            path=str(self.chroma_path),
            settings=ChromaSettings(anonymized_telemetry=False)
        )
        # Initialize the embedding function
        self.embedding_function = self._get_embedding_function()
        # Get or create the collection
        try:
            # Try to fetch an existing collection; pass the same embedding
            # function so queries embed text consistently with indexing
            self.collection = self.client.get_collection(
                name=self.settings.collection_name,
                embedding_function=self.embedding_function
            )
        except Exception:
            # Collection does not exist yet; create it
            self.collection = self.client.create_collection(
                name=self.settings.collection_name,
                embedding_function=self.embedding_function,
                metadata={"description": "BidMaster知识库"}
            )
        # Initialize the text splitter
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.settings.chunk_size,
            chunk_overlap=self.settings.chunk_overlap,
            length_function=len,
        )

    def add_document(self, file_path: str) -> bool:
        """Add a document to the knowledge base."""
        file_path_obj = Path(file_path)
        if not file_path_obj.exists():
            raise FileNotFoundError(f"文件不存在: {file_path}")
        # Load the document
        documents = self._load_document(file_path_obj)
        if not documents:
            raise ValueError(f"未能从文件中提取内容: {file_path}")
        # Split the document into chunks
        chunks = self.text_splitter.split_documents(documents)
        # Add the chunks to the vector database
        self._add_chunks_to_db(chunks, file_path)
        return True

    def search(self, query: str, k: int = 5) -> list[dict[str, Any]]:
        """Search the knowledge base for relevant content."""
        results = self.collection.query(
            query_texts=[query],
            n_results=k,
            include=["documents", "metadatas", "distances"]
        )
        # Format the results; score converts the returned distance into a
        # similarity-style value (higher means more relevant)
        formatted_results = []
        if results["documents"] and results["documents"][0]:
            for i, doc in enumerate(results["documents"][0]):
                result = {
                    "content": doc,
                    "metadata": results["metadatas"][0][i] if results["metadatas"] else {},
                    "score": 1 - results["distances"][0][i] if results["distances"] else 0.0
                }
                formatted_results.append(result)
        return formatted_results

    def get_stats(self) -> dict[str, Any]:
        """Return statistics about the knowledge base."""
        count = self.collection.count()
        files = set()
        # Collect the source file path of every stored chunk
        if count > 0:
            all_data = self.collection.get(include=["metadatas"])
            for metadata in all_data["metadatas"]:
                if "source" in metadata:
                    files.add(metadata["source"])
        return {
            "total_chunks": count,
            "total_files": len(files),
            "files": list(files)
        }

def reset_database(self):
"""重置数据库"""
# 删除集合
self.client.delete_collection(name=self.settings.collection_name)
# 重新创建集合
self.collection = self.client.get_or_create_collection(
name=self.settings.collection_name,
embedding_function=self.embedding_function,
metadata={"description": "BidMaster知识库"}
)
return True
    def _load_document(self, file_path: Path) -> list[Document]:
        """Load a document with a loader chosen by file extension."""
        suffix = file_path.suffix.lower()
        loaders = {
            ".pdf": PyPDFLoader,
            ".txt": TextLoader,
            ".md": TextLoader,
            ".docx": UnstructuredWordDocumentLoader,
        }
        loader_class = loaders.get(suffix)
        if not loader_class:
            raise ValueError(f"不支持的文件格式: {suffix}")
        # Pass an explicit encoding for plain-text files
        if suffix in [".txt", ".md"]:
            loader = loader_class(str(file_path), encoding="utf-8")
        else:
            loader = loader_class(str(file_path))
        return loader.load()

    def _add_chunks_to_db(self, chunks: list[Document], source_file: str) -> None:
        """Add document chunks to the database."""
        if not chunks:
            return
        documents = []
        metadatas = []
        ids = []
        for i, chunk in enumerate(chunks):
            # Generate a unique ID for the chunk
            chunk_id = self._generate_chunk_id(source_file, i, chunk.page_content)
            documents.append(chunk.page_content)
            metadatas.append({
                "source": source_file,
                "chunk_index": i,
                "chunk_size": len(chunk.page_content),
                **chunk.metadata
            })
            ids.append(chunk_id)
        # Add the batch to ChromaDB
        self.collection.add(
            documents=documents,
            metadatas=metadatas,
            ids=ids
        )

    def _generate_chunk_id(self, source_file: str, chunk_index: int, content: str) -> str:
        """Generate a unique ID for a chunk."""
        content_hash = hashlib.md5(content.encode()).hexdigest()[:8]
        return f"{Path(source_file).stem}_{chunk_index}_{content_hash}"

    def generate_content(self, task_id: str, context: dict, **kwargs) -> str:
        """Generate content for a task.

        Args:
            task_id: Task ID.
            context: Context information.
            **kwargs: Additional generation parameters.

        Returns:
            The generated content.
        """
        from openai import OpenAI

        # Extract task information from the context
        task_title = context.get('title', '任务')
        emphasis = context.get('emphasis', '')
        rag_context = context.get('rag_context', '')
        # Build the prompt variables
        emphasis_part = f'\n特别强调:{emphasis}' if emphasis else ''
        rag_part = f'\n\n参考资料:\n{rag_context}' if rag_context else ''
        guidance_notes = (context.get('guidance_notes') or '').strip()
        guidance_part = guidance_notes or '(暂无重点提示)'
        prompt_variables = {
            "title": task_title,
            "chapter_path": context.get('chapter_path', task_title),
            "score_info": context.get('score_info', '目标得分:未明确'),
            "requirements_summary": context.get('requirements_summary', ''),
            "rubric_points": context.get('rubric_points', '- 无明确评分要点'),
            "objectives": context.get('objectives', '1. 围绕章节主题输出详实内容'),
            "consistency_rules": context.get('consistency_rules', '1. 保持章节语气与格式一致'),
            "context_summary": context.get('context_summary', '(暂无可引用的上下文)'),
            "emphasis_part": emphasis_part,
            "guidance_part": guidance_part,
            "rag_part": rag_part,
        }
        # Fetch the prompt template from configuration
        prompt_manager = get_prompt_manager()
        prompt = prompt_manager.get_content_prompt("generate_with_rag", **prompt_variables)
        # Call the LLM to generate the content
        client = OpenAI(
            api_key=self.settings.api_key,
            base_url=self.settings.base_url,
            timeout=180,
            max_retries=2,
        )
        response = client.chat.completions.create(
            model=self.settings.model_name,
            messages=[
                {"role": "system", "content": prompt_manager.get_system_message("rag_generator")},
                {"role": "user", "content": prompt}
            ],
            temperature=0.7,
            max_tokens=2000
        )
        generated_content = response.choices[0].message.content.strip()
        return generated_content

    def _get_embedding_function(self):
        """Return the embedding function based on the configured model."""
        embedding_model = self.settings.embedding_model
        if embedding_model.startswith("text-embedding-"):
            # OpenAI embedding model
            return embedding_functions.OpenAIEmbeddingFunction(
                api_key=self.settings.api_key,
                model_name=embedding_model
            )
        else:
            # Local sentence-transformers model
            return embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=embedding_model
            )
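

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: indexes one file,
    # runs a similarity search, and prints knowledge-base statistics. It assumes
    # get_settings() supplies valid chroma_path, collection_name, embedding_model,
    # and (for OpenAI embeddings or generate_content) API credentials. The file
    # path and query below are hypothetical placeholders.
    logging.basicConfig(level=logging.INFO)
    rag = RAGTool()
    rag.add_document("docs/sample.txt")  # hypothetical path to a .pdf/.txt/.md/.docx file
    for hit in rag.search("项目实施方案", k=3):
        print(f"{hit['score']:.3f}  {hit['metadata'].get('source', '')}")
        print(hit["content"][:120])
    print(rag.get_stats())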