diff --git a/src/bidmaster/tools/rag.py b/src/bidmaster/tools/rag.py index e4962fc..adb67ac 100644 --- a/src/bidmaster/tools/rag.py +++ b/src/bidmaster/tools/rag.py @@ -43,11 +43,18 @@ class RAGTool: self.embedding_function = self._get_embedding_function() # 获取或创建集合 - self.collection = self.client.get_or_create_collection( - name=self.settings.collection_name, - embedding_function=self.embedding_function, - metadata={"description": "BidMaster知识库"} - ) + try: + # 尝试获取已存在的集合 + self.collection = self.client.get_collection( + name=self.settings.collection_name + ) + except Exception: + # 集合不存在,创建新集合 + self.collection = self.client.create_collection( + name=self.settings.collection_name, + embedding_function=self.embedding_function, + metadata={"description": "BidMaster知识库"} + ) # 初始化文本分割器 self.text_splitter = RecursiveCharacterTextSplitter( @@ -74,6 +81,8 @@ class RAGTool: # 添加到向量数据库 self._add_chunks_to_db(chunks, file_path) + return True + def search(self, query: str, k: int = 5) -> list[dict[str, Any]]: """搜索相关内容""" results = self.collection.query( @@ -121,9 +130,12 @@ class RAGTool: # 重新创建集合 self.collection = self.client.get_or_create_collection( name=self.settings.collection_name, + embedding_function=self.embedding_function, metadata={"description": "BidMaster知识库"} ) + return True + def _load_document(self, file_path: Path) -> list[Document]: """根据文件类型加载文档""" suffix = file_path.suffix.lower()