# Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
"""RAG (retrieval-augmented generation) tool.

ChromaDB-backed document retrieval: document indexing, similarity search
and content retrieval.
"""
import hashlib
|
||
import logging
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
import chromadb
|
||
from chromadb.config import Settings as ChromaSettings
|
||
from chromadb.utils import embedding_functions
|
||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||
from langchain_community.document_loaders import (
|
||
PyPDFLoader,
|
||
TextLoader,
|
||
UnstructuredWordDocumentLoader,
|
||
)
|
||
from langchain_core.documents import Document
|
||
|
||
from ..config import get_settings
|
||
from ..config.prompt_manager import get_prompt_manager
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class RAGTool:
    """RAG (retrieval-augmented generation) tool.

    Wraps a persistent ChromaDB collection: indexes documents, runs
    similarity search over them, and generates content grounded in the
    retrieved chunks.
    """

    def __init__(self) -> None:
        """Initialize the Chroma client, embedding function, collection and text splitter."""
        self.settings = get_settings()
        self.chroma_path = Path(self.settings.chroma_path)
        self.chroma_path.mkdir(parents=True, exist_ok=True)

        # Persistent ChromaDB client; anonymized telemetry disabled.
        self.client = chromadb.PersistentClient(
            path=str(self.chroma_path),
            settings=ChromaSettings(anonymized_telemetry=False),
        )

        # Embedding function must be resolved before touching the collection.
        self.embedding_function = self._get_embedding_function()

        # FIX: the previous get_collection/except/create_collection pattern
        # attached the embedding function only on the *create* path, so an
        # existing collection was opened WITHOUT it and queries silently fell
        # back to Chroma's default embedder.  get_or_create_collection (the
        # same call reset_database already uses) attaches the configured
        # embedding function in both cases and removes the broad except.
        self.collection = self.client.get_or_create_collection(
            name=self.settings.collection_name,
            embedding_function=self.embedding_function,
            metadata={"description": "BidMaster知识库"},
        )

        # Splitter used to chunk documents before indexing.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.settings.chunk_size,
            chunk_overlap=self.settings.chunk_overlap,
            length_function=len,
        )
||
def add_document(self, file_path: str):
|
||
"""添加文档到知识库"""
|
||
file_path_obj = Path(file_path)
|
||
|
||
if not file_path_obj.exists():
|
||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||
|
||
# 加载文档
|
||
documents = self._load_document(file_path_obj)
|
||
if not documents:
|
||
raise ValueError(f"未能从文件中提取内容: {file_path}")
|
||
|
||
# 分割文档
|
||
chunks = self.text_splitter.split_documents(documents)
|
||
|
||
# 添加到向量数据库
|
||
self._add_chunks_to_db(chunks, file_path)
|
||
|
||
return True
|
||
|
||
def search(self, query: str, k: int = 5) -> list[dict[str, Any]]:
|
||
"""搜索相关内容"""
|
||
results = self.collection.query(
|
||
query_texts=[query],
|
||
n_results=k,
|
||
include=["documents", "metadatas", "distances"]
|
||
)
|
||
|
||
# 格式化结果
|
||
formatted_results = []
|
||
if results["documents"] and results["documents"][0]:
|
||
for i, doc in enumerate(results["documents"][0]):
|
||
result = {
|
||
"content": doc,
|
||
"metadata": results["metadatas"][0][i] if results["metadatas"] else {},
|
||
"score": 1 - results["distances"][0][i] if results["distances"] else 0.0
|
||
}
|
||
formatted_results.append(result)
|
||
|
||
return formatted_results
|
||
|
||
def get_stats(self) -> dict[str, Any]:
|
||
"""获取知识库统计信息"""
|
||
count = self.collection.count()
|
||
files = set()
|
||
|
||
# 获取所有文档的文件路径
|
||
if count > 0:
|
||
all_data = self.collection.get(include=["metadatas"])
|
||
for metadata in all_data["metadatas"]:
|
||
if "source" in metadata:
|
||
files.add(metadata["source"])
|
||
|
||
return {
|
||
"total_chunks": count,
|
||
"total_files": len(files),
|
||
"files": list(files)
|
||
}
|
||
|
||
def reset_database(self):
|
||
"""重置数据库"""
|
||
# 删除集合
|
||
self.client.delete_collection(name=self.settings.collection_name)
|
||
|
||
# 重新创建集合
|
||
self.collection = self.client.get_or_create_collection(
|
||
name=self.settings.collection_name,
|
||
embedding_function=self.embedding_function,
|
||
metadata={"description": "BidMaster知识库"}
|
||
)
|
||
|
||
return True
|
||
|
||
def _load_document(self, file_path: Path) -> list[Document]:
|
||
"""根据文件类型加载文档"""
|
||
suffix = file_path.suffix.lower()
|
||
|
||
loaders = {
|
||
".pdf": PyPDFLoader,
|
||
".txt": TextLoader,
|
||
".md": TextLoader,
|
||
".docx": UnstructuredWordDocumentLoader,
|
||
}
|
||
|
||
loader_class = loaders.get(suffix)
|
||
if not loader_class:
|
||
raise ValueError(f"不支持的文件格式: {suffix}")
|
||
|
||
# 使用encoding参数处理文本文件
|
||
if suffix in [".txt", ".md"]:
|
||
loader = loader_class(str(file_path), encoding="utf-8")
|
||
else:
|
||
loader = loader_class(str(file_path))
|
||
|
||
return loader.load()
|
||
|
||
def _add_chunks_to_db(self, chunks: list[Document], source_file: str) -> None:
|
||
"""将文档块添加到数据库"""
|
||
if not chunks:
|
||
return
|
||
|
||
documents = []
|
||
metadatas = []
|
||
ids = []
|
||
|
||
for i, chunk in enumerate(chunks):
|
||
# 生成唯一ID
|
||
chunk_id = self._generate_chunk_id(source_file, i, chunk.page_content)
|
||
|
||
documents.append(chunk.page_content)
|
||
metadatas.append({
|
||
"source": source_file,
|
||
"chunk_index": i,
|
||
"chunk_size": len(chunk.page_content),
|
||
**chunk.metadata
|
||
})
|
||
ids.append(chunk_id)
|
||
|
||
# 批量添加到ChromaDB
|
||
self.collection.add(
|
||
documents=documents,
|
||
metadatas=metadatas,
|
||
ids=ids
|
||
)
|
||
|
||
def _generate_chunk_id(self, source_file: str, chunk_index: int, content: str) -> str:
|
||
"""生成块的唯一ID"""
|
||
content_hash = hashlib.md5(content.encode()).hexdigest()[:8]
|
||
return f"{Path(source_file).stem}_{chunk_index}_{content_hash}"
|
||
|
||
def generate_content(self, task_id: str, context: dict, **kwargs) -> str:
|
||
"""生成内容
|
||
|
||
Args:
|
||
task_id: 任务ID
|
||
context: 上下文信息
|
||
**kwargs: 其他生成参数
|
||
|
||
Returns:
|
||
生成的内容
|
||
"""
|
||
from openai import OpenAI
|
||
|
||
# 从上下文中提取任务信息
|
||
task_title = context.get('title', '任务')
|
||
emphasis = context.get('emphasis', '')
|
||
rag_context = context.get('rag_context', '')
|
||
|
||
# 构建提示词变量
|
||
emphasis_part = f'\n特别强调:{emphasis}' if emphasis else ''
|
||
rag_part = f'\n\n参考资料:\n{rag_context}' if rag_context else ''
|
||
guidance_notes = (context.get('guidance_notes') or '').strip()
|
||
guidance_part = guidance_notes or '(暂无重点提示)'
|
||
|
||
prompt_variables = {
|
||
"title": task_title,
|
||
"chapter_path": context.get('chapter_path', task_title),
|
||
"score_info": context.get('score_info', '目标得分:未明确'),
|
||
"requirements_summary": context.get('requirements_summary', ''),
|
||
"rubric_points": context.get('rubric_points', '- 无明确评分要点'),
|
||
"objectives": context.get('objectives', '1. 围绕章节主题输出详实内容'),
|
||
"consistency_rules": context.get('consistency_rules', '1. 保持章节语气与格式一致'),
|
||
"context_summary": context.get('context_summary', '(暂无可引用的上下文)'),
|
||
"emphasis_part": emphasis_part,
|
||
"guidance_part": guidance_part,
|
||
"rag_part": rag_part,
|
||
}
|
||
|
||
# 从配置获取提示词
|
||
prompt_manager = get_prompt_manager()
|
||
prompt = prompt_manager.get_content_prompt("generate_with_rag", **prompt_variables)
|
||
|
||
# 调用LLM生成
|
||
client = OpenAI(
|
||
api_key=self.settings.api_key,
|
||
base_url=self.settings.base_url,
|
||
timeout=180,
|
||
max_retries=2,
|
||
)
|
||
|
||
response = client.chat.completions.create(
|
||
model=self.settings.model_name,
|
||
messages=[
|
||
{"role": "system", "content": prompt_manager.get_system_message("rag_generator")},
|
||
{"role": "user", "content": prompt}
|
||
],
|
||
temperature=0.7,
|
||
max_tokens=2000
|
||
)
|
||
|
||
generated_content = response.choices[0].message.content.strip()
|
||
|
||
return generated_content
|
||
|
||
def _get_embedding_function(self):
|
||
"""获取嵌入函数"""
|
||
embedding_model = self.settings.embedding_model
|
||
|
||
if embedding_model.startswith("text-embedding-"):
|
||
# OpenAI嵌入模型
|
||
return embedding_functions.OpenAIEmbeddingFunction(
|
||
api_key=self.settings.api_key,
|
||
model_name=embedding_model
|
||
)
|
||
else:
|
||
# 本地sentence-transformers模型
|
||
return embedding_functions.SentenceTransformerEmbeddingFunction(
|
||
model_name=embedding_model
|
||
) |