# Source listing metadata: 1160 lines, 39 KiB, Python.
import json
import os
from contextlib import ExitStack
from typing import Any, Dict, Generator, Iterable, List, Optional, Union

import requests
class RAGFlowError(Exception):
    """Error reported by the RAGFlow API, or an HTTP failure while talking to it.

    Attributes:
        code: RAGFlow error code from the response body, or the HTTP status code
            when the body carried none. May be None if the server omitted it.
        message: Error message returned by the server.
    """

    def __init__(self, code: Optional[int], message: Optional[str]):
        self.code = code
        self.message = message
        super().__init__(f"Error {code}: {message}")

    def __reduce__(self):
        # Default Exception pickling re-invokes the class with self.args (the one
        # formatted string), which does not fit this two-argument __init__.
        # Rebuild from (code, message) so the error survives pickling, e.g. when
        # raised inside a multiprocessing / concurrent.futures worker.
        return (type(self), (self.code, self.message))
class RAGFlowClient:
    """Client for the RAGFlow HTTP API.

    Each public method maps to one REST endpoint. Non-streaming calls return the
    decoded JSON response as a dict; streaming calls return a generator of the
    decoded server-sent events. API-level failures raise RAGFlowError.
    """

    def __init__(self, base_url: str, api_key: str, timeout: Optional[float] = None):
        """
        Create a client.

        Args:
            base_url: RAGFlow server address; a trailing "/" is removed.
            api_key: API key, sent as a Bearer token.
            timeout: Seconds passed to requests as ``timeout`` on every call this
                client makes. None (the default) waits indefinitely.
        """
        self.base_url = base_url.rstrip('/')
        self.api_key = api_key
        self.timeout = timeout
        self.headers = {
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json',
        }

    # ====================
    # Internal helpers
    # ====================

    def _auth_headers(self) -> Dict[str, str]:
        """Return headers carrying only the Authorization token (no JSON Content-Type),
        for multipart uploads where requests must set Content-Type itself, and downloads."""
        return {'Authorization': f'Bearer {self.api_key}'}

    @staticmethod
    def _compact(mapping: Dict[str, Any], drop_falsy: bool = False) -> Dict[str, Any]:
        """Return a copy of ``mapping`` without None values, or without any falsy
        value when ``drop_falsy`` is true."""
        if drop_falsy:
            return {key: value for key, value in mapping.items() if value}
        return {key: value for key, value in mapping.items() if value is not None}

    @staticmethod
    def _list_params(page: int, page_size: int, orderby: str, desc: bool,
                     **filters: Any) -> Dict[str, Any]:
        """Build the shared pagination/sort query parameters plus every truthy filter.

        ``desc`` stays a Python bool, which requests serializes as "True"/"False".
        """
        params = {"page": page, "page_size": page_size, "orderby": orderby, "desc": desc}
        params.update((key, value) for key, value in filters.items() if value)
        return params

    @staticmethod
    def _parse_response(response) -> Dict[str, Any]:
        """
        Decode a requests Response into the RAGFlow JSON result.

        Returns:
            The decoded JSON dict. A non-JSON body with HTTP 200 is returned as
            ``{'code': 0, 'data': <raw bytes>}``.

        Raises:
            RAGFlowError: the body has a non-zero "code", or the HTTP status is
                4xx/5xx, or the body is not JSON and the status is not 200.
        """
        try:
            result = response.json()
        except ValueError:
            # requests raises a ValueError subclass whose concrete type depends on
            # the requests version and on whether simplejson is installed, so
            # json.JSONDecodeError alone is not reliable here.
            if response.status_code == 200:
                return {'code': 0, 'data': response.content}
            raise RAGFlowError(response.status_code, response.text) from None

        code = result.get('code', 0)
        if code != 0:
            raise RAGFlowError(code, result.get('message', 'Unknown error'))
        if not response.ok:
            # HTTP failure whose JSON body carries no RAGFlow error code.
            raise RAGFlowError(response.status_code, result.get('message', response.text))
        return result

    @staticmethod
    def _iter_sse_events(lines: Iterable[Union[bytes, str]]) -> Generator[Any, None, None]:
        """
        Yield the decoded JSON payload of every "data:" line of an SSE stream.

        Empty lines, other SSE fields ("event:", "id:", ": comment") and payloads
        that are not valid JSON are skipped. A line that is a bare JSON object is
        not valid SSE; it means the server answered with a plain JSON body instead
        of a stream. If that body has a non-zero "code", RAGFlowError is raised
        instead of silently producing no events.
        """
        for raw in lines:
            if not raw:
                continue
            line = raw.decode('utf-8') if isinstance(raw, bytes) else raw
            if line.startswith('data:'):
                try:
                    event = json.loads(line[5:].strip())
                except json.JSONDecodeError:
                    continue
                yield event
            elif line.startswith('{'):
                try:
                    body = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if isinstance(body, dict) and body.get('code', 0) != 0:
                    raise RAGFlowError(body.get('code'), body.get('message', 'Unknown error'))

    def _request(self, method: str, endpoint: str, **kwargs) -> Dict[str, Any]:
        """
        Send an HTTP request and return the decoded JSON result.

        Keyword arguments are forwarded to ``requests.request``. A ``headers``
        argument replaces (does not merge with) the default JSON headers. The
        client timeout is used unless ``timeout`` is passed explicitly.

        Raises:
            RAGFlowError: see ``_parse_response``.
        """
        url = f"{self.base_url}{endpoint}"
        headers = kwargs.pop('headers', None)
        if headers is None:
            headers = self.headers.copy()
        kwargs.setdefault('timeout', self.timeout)
        response = requests.request(method, url, headers=headers, **kwargs)
        return self._parse_response(response)

    def _stream_request(self, method: str, endpoint: str, **kwargs) -> Generator[Dict[str, Any], None, None]:
        """
        Send a streaming request and yield each decoded SSE event as it arrives.

        The request is only sent when the generator is first iterated. The
        connection is closed when the stream ends, on error, or when the consumer
        stops iterating early.

        Raises:
            RAGFlowError: HTTP 4xx/5xx, or a JSON error body instead of a stream.
        """
        url = f"{self.base_url}{endpoint}"
        headers = kwargs.pop('headers', None)
        # Copy so a caller-supplied dict is never mutated by the update below.
        headers = dict(self.headers if headers is None else headers)
        # Ask intermediaries not to buffer the stream (X-Accel-Buffering targets Nginx).
        headers.update({
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
            'X-Accel-Buffering': 'no',
        })
        kwargs.setdefault('timeout', self.timeout)

        response = requests.request(method, url, headers=headers, stream=True, **kwargs)
        try:
            if not response.ok:
                # _parse_response raises RAGFlowError for every 4xx/5xx response.
                self._parse_response(response)
            # chunk_size=None makes requests hand over data as each chunk arrives;
            # the default (512) holds events back until 512 bytes accumulate.
            yield from self._iter_sse_events(response.iter_lines(chunk_size=None))
        finally:
            response.close()

    def _post_json(self, endpoint: str, data: Dict[str, Any],
                   stream: bool) -> Union[Dict[str, Any], Generator[Dict[str, Any], None, None]]:
        """POST ``data`` as JSON; return an SSE event generator when ``stream`` is true."""
        if stream:
            return self._stream_request('POST', endpoint, json=data)
        return self._request('POST', endpoint, json=data)

    # ====================
    # OpenAI-compatible API
    # ====================

    def create_chat_completion(self, chat_id: str, model: str, messages: List[Dict[str, str]],
                               stream: bool = False) -> Union[Dict[str, Any], Generator[Dict[str, Any], None, None]]:
        """
        Create a chat completion through a chat assistant's OpenAI-compatible endpoint.

        Args:
            chat_id: Chat assistant ID.
            model: Model name.
            messages: OpenAI-style message list.
            stream: If True, return a generator of decoded SSE events.
        """
        endpoint = f"/api/v1/chats_openai/{chat_id}/chat/completions"
        data = {"model": model, "messages": messages, "stream": stream}
        return self._post_json(endpoint, data, stream)

    def create_agent_completion(self, agent_id: str, model: str, messages: List[Dict[str, str]],
                                stream: bool = False) -> Union[Dict[str, Any], Generator[Dict[str, Any], None, None]]:
        """
        Create a completion through an agent's OpenAI-compatible endpoint.

        Args:
            agent_id: Agent ID.
            model: Model name.
            messages: OpenAI-style message list.
            stream: If True, return a generator of decoded SSE events.
        """
        endpoint = f"/api/v1/agents_openai/{agent_id}/chat/completions"
        data = {"model": model, "messages": messages, "stream": stream}
        return self._post_json(endpoint, data, stream)

    # ====================
    # Dataset management
    # ====================

    def create_dataset(self, name: str, avatar: Optional[str] = None, description: Optional[str] = None,
                       embedding_model: Optional[str] = None, permission: str = "me",
                       chunk_method: str = "naive",
                       parser_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Create a dataset.

        Args:
            name: Dataset name.
            avatar: Base64-encoded avatar.
            description: Description.
            embedding_model: Embedding model name.
            permission: "me" or "team".
            chunk_method: Chunking method: "naive" (General, default), "book",
                "email", "laws", "manual", "one", "paper", "picture",
                "presentation", "qa" (Q&A), "table" or "tag".
            parser_config: Parser settings. Keys include:
                auto_keywords (int, 0-32, default 0);
                auto_questions (int, 0-10, default 0);
                chunk_token_num (int, 1-2048, default 512);
                delimiter (str, default "\\n");
                html4excel (bool, convert Excel documents to HTML, default False);
                layout_recognize (str, default "DeepDOC");
                tag_kb_ids (list of dataset IDs parsed with the Tag method);
                task_page_size (int >= 1, PDF only, default 12);
                raptor (dict, default {"use_raptor": False});
                graphrag (dict, default {"use_graphrag": False}).

        Optional arguments that are None or empty are not sent.
        """
        data = {"name": name, "permission": permission, "chunk_method": chunk_method}
        data.update(self._compact({
            "avatar": avatar,
            "description": description,
            "embedding_model": embedding_model,
            "parser_config": parser_config,
        }, drop_falsy=True))
        return self._request('POST', "/api/v1/datasets", json=data)

    def delete_datasets(self, ids: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Delete datasets.

        Args:
            ids: Dataset IDs to delete. None is sent as "ids": null, which deletes
                all datasets.
        """
        return self._request('DELETE', "/api/v1/datasets", json={"ids": ids})

    def update_dataset(self, dataset_id: str, name: Optional[str] = None,
                       avatar: Optional[str] = None, description: Optional[str] = None,
                       embedding_model: Optional[str] = None, permission: Optional[str] = None,
                       chunk_method: Optional[str] = None, pagerank: Optional[int] = None,
                       parser_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Update a dataset. Only arguments that are not None are sent.

        Args:
            dataset_id: Dataset ID.
            name: New name.
            avatar: New avatar.
            description: New description.
            embedding_model: New embedding model.
            permission: New permission ("me" or "team").
            chunk_method: New chunking method.
            pagerank: New page rank.
            parser_config: New parser configuration.
        """
        data = self._compact({
            "name": name,
            "avatar": avatar,
            "description": description,
            "embedding_model": embedding_model,
            "permission": permission,
            "chunk_method": chunk_method,
            "pagerank": pagerank,
            "parser_config": parser_config,
        })
        return self._request('PUT', f"/api/v1/datasets/{dataset_id}", json=data)

    def list_datasets(self, page: int = 1, page_size: int = 30, orderby: str = "create_time",
                      desc: bool = True, name: Optional[str] = None,
                      dataset_id: Optional[str] = None) -> Dict[str, Any]:
        """
        List datasets.

        Args:
            page: Page number.
            page_size: Items per page.
            orderby: Sort field.
            desc: Sort descending.
            name: Filter by name.
            dataset_id: Filter by ID.
        """
        params = self._list_params(page, page_size, orderby, desc, name=name, id=dataset_id)
        return self._request('GET', "/api/v1/datasets", params=params)

    # ====================
    # Document management
    # ====================

    def upload_documents(self, dataset_id: str, file_paths: List[str]) -> Dict[str, Any]:
        """
        Upload files to a dataset as a multipart request.

        Args:
            dataset_id: Dataset ID.
            file_paths: Local file paths. Paths that do not exist are skipped.
        """
        endpoint = f"/api/v1/datasets/{dataset_id}/documents"
        with ExitStack() as stack:
            # ExitStack closes every opened file even if a later open() or the
            # request itself raises.
            files = [('file', stack.enter_context(open(path, 'rb')))
                     for path in file_paths if os.path.exists(path)]
            return self._request('POST', endpoint, headers=self._auth_headers(), files=files)

    def update_document(self, dataset_id: str, document_id: str, name: Optional[str] = None,
                        meta_fields: Optional[Dict[str, Any]] = None,
                        chunk_method: Optional[str] = None,
                        parser_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Update a document's configuration. Only arguments that are not None are sent.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            name: New name.
            meta_fields: Metadata fields.
            chunk_method: Chunking method.
            parser_config: Parser configuration.
        """
        data = self._compact({
            "name": name,
            "meta_fields": meta_fields,
            "chunk_method": chunk_method,
            "parser_config": parser_config,
        })
        endpoint = f"/api/v1/datasets/{dataset_id}/documents/{document_id}"
        return self._request('PUT', endpoint, json=data)

    def download_document(self, dataset_id: str, document_id: str, save_path: str) -> None:
        """
        Download a document and stream it to ``save_path``.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            save_path: Destination file path (overwritten).

        Raises:
            RAGFlowError: non-200 response, or a JSON error body with a non-zero code.
        """
        endpoint = f"/api/v1/datasets/{dataset_id}/documents/{document_id}"
        response = requests.get(f"{self.base_url}{endpoint}", headers=self._auth_headers(),
                                stream=True, timeout=self.timeout)
        try:
            if response.status_code != 200:
                try:
                    error = response.json()
                except ValueError:
                    raise RAGFlowError(response.status_code, response.text) from None
                raise RAGFlowError(error.get('code', response.status_code), error.get('message'))

            if 'application/json' in response.headers.get('Content-Type', ''):
                # API errors may arrive as an HTTP-200 JSON envelope; this raises
                # RAGFlowError for a non-zero code instead of saving the error text
                # as the document.
                self._parse_response(response)

            with open(save_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=65536):
                    f.write(chunk)
        finally:
            response.close()

    def list_documents(self, dataset_id: str, page: int = 1, page_size: int = 30,
                       orderby: str = "create_time", desc: bool = True,
                       keywords: Optional[str] = None, document_id: Optional[str] = None,
                       document_name: Optional[str] = None) -> Dict[str, Any]:
        """
        List documents in a dataset.

        Args:
            dataset_id: Dataset ID.
            page: Page number.
            page_size: Items per page.
            orderby: Sort field.
            desc: Sort descending.
            keywords: Keyword search.
            document_id: Filter by document ID.
            document_name: Filter by document name.
        """
        params = self._list_params(page, page_size, orderby, desc,
                                   keywords=keywords, id=document_id, name=document_name)
        return self._request('GET', f"/api/v1/datasets/{dataset_id}/documents", params=params)

    def delete_documents(self, dataset_id: str, ids: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Delete documents.

        Args:
            dataset_id: Dataset ID.
            ids: Document IDs to delete. When None or empty, an empty JSON body is
                sent. NOTE(review): confirm the server's meaning of that before
                relying on it; it may delete every document in the dataset.
        """
        data = {"ids": ids} if ids else {}
        return self._request('DELETE', f"/api/v1/datasets/{dataset_id}/documents", json=data)

    def parse_documents(self, dataset_id: str, document_ids: List[str]) -> Dict[str, Any]:
        """
        Start parsing documents.

        Args:
            dataset_id: Dataset ID.
            document_ids: IDs of the documents to parse.
        """
        endpoint = f"/api/v1/datasets/{dataset_id}/chunks"
        return self._request('POST', endpoint, json={"document_ids": document_ids})

    def stop_parsing_documents(self, dataset_id: str, document_ids: List[str]) -> Dict[str, Any]:
        """
        Stop parsing documents.

        Args:
            dataset_id: Dataset ID.
            document_ids: IDs of the documents whose parsing should stop.
        """
        endpoint = f"/api/v1/datasets/{dataset_id}/chunks"
        return self._request('DELETE', endpoint, json={"document_ids": document_ids})

    # ====================
    # Chunk management
    # ====================

    def add_chunk(self, dataset_id: str, document_id: str, content: str,
                  important_keywords: Optional[List[str]] = None,
                  questions: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Add a chunk to a document.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            content: Chunk text.
            important_keywords: Key terms (omitted when empty).
            questions: Questions (omitted when empty).
        """
        endpoint = f"/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks"
        data = {"content": content}
        data.update(self._compact({
            "important_keywords": important_keywords,
            "questions": questions,
        }, drop_falsy=True))
        return self._request('POST', endpoint, json=data)

    def list_chunks(self, dataset_id: str, document_id: str, keywords: Optional[str] = None,
                    page: int = 1, page_size: int = 1024,
                    chunk_id: Optional[str] = None) -> Dict[str, Any]:
        """
        List a document's chunks.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            keywords: Keyword search.
            page: Page number.
            page_size: Items per page.
            chunk_id: Filter by chunk ID.
        """
        endpoint = f"/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks"
        params = {"page": page, "page_size": page_size}
        params.update(self._compact({"keywords": keywords, "id": chunk_id}, drop_falsy=True))
        return self._request('GET', endpoint, params=params)

    def delete_chunks(self, dataset_id: str, document_id: str,
                      chunk_ids: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Delete chunks.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            chunk_ids: Chunk IDs to delete. When None or empty, an empty JSON body
                is sent. NOTE(review): confirm whether the server then deletes all
                chunks of the document.
        """
        endpoint = f"/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks"
        data = {"chunk_ids": chunk_ids} if chunk_ids else {}
        return self._request('DELETE', endpoint, json=data)

    def update_chunk(self, dataset_id: str, document_id: str, chunk_id: str,
                     content: Optional[str] = None,
                     important_keywords: Optional[List[str]] = None,
                     available: Optional[bool] = None) -> Dict[str, Any]:
        """
        Update a chunk. Only arguments that are not None are sent.

        Args:
            dataset_id: Dataset ID.
            document_id: Document ID.
            chunk_id: Chunk ID.
            content: New text.
            important_keywords: Key terms.
            available: Availability flag.
        """
        endpoint = f"/api/v1/datasets/{dataset_id}/documents/{document_id}/chunks/{chunk_id}"
        data = self._compact({
            "content": content,
            "important_keywords": important_keywords,
            "available": available,
        })
        return self._request('PUT', endpoint, json=data)

    def retrieve_chunks(self, question: str, dataset_ids: Optional[List[str]] = None,
                        document_ids: Optional[List[str]] = None, page: int = 1,
                        page_size: int = 30, similarity_threshold: float = 0.2,
                        vector_similarity_weight: float = 0.3, top_k: int = 1024,
                        rerank_id: Optional[str] = None, keyword: bool = False,
                        highlight: bool = False) -> Dict[str, Any]:
        """
        Retrieve chunks relevant to a question.

        Args:
            question: Query text.
            dataset_ids: Dataset IDs to search (omitted when empty).
            document_ids: Document IDs to search (omitted when empty).
            page: Page number.
            page_size: Items per page.
            similarity_threshold: Minimum similarity.
            vector_similarity_weight: Weight of vector similarity.
            top_k: Number of candidates.
            rerank_id: Rerank model ID (omitted when empty).
            keyword: Enable keyword matching.
            highlight: Highlight matched terms.
        """
        data = {
            "question": question,
            "page": page,
            "page_size": page_size,
            "similarity_threshold": similarity_threshold,
            "vector_similarity_weight": vector_similarity_weight,
            "top_k": top_k,
            "keyword": keyword,
            "highlight": highlight,
        }
        data.update(self._compact({
            "dataset_ids": dataset_ids,
            "document_ids": document_ids,
            "rerank_id": rerank_id,
        }, drop_falsy=True))
        return self._request('POST', "/api/v1/retrieval", json=data)

    # ====================
    # Chat assistant management
    # ====================

    def create_chat_assistant(self, name: str, avatar: Optional[str] = None,
                              dataset_ids: Optional[List[str]] = None,
                              llm: Optional[Dict[str, Any]] = None,
                              prompt: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Create a chat assistant. Empty optional arguments are not sent.

        Args:
            name: Assistant name.
            avatar: Base64-encoded avatar.
            dataset_ids: Linked dataset IDs.
            llm: LLM settings.
            prompt: Prompt settings.
        """
        data = {"name": name}
        data.update(self._compact({
            "avatar": avatar,
            "dataset_ids": dataset_ids,
            "llm": llm,
            "prompt": prompt,
        }, drop_falsy=True))
        return self._request('POST', "/api/v1/chats", json=data)

    def update_chat_assistant(self, chat_id: str, name: Optional[str] = None,
                              avatar: Optional[str] = None,
                              dataset_ids: Optional[List[str]] = None,
                              llm: Optional[Dict[str, Any]] = None,
                              prompt: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Update a chat assistant. Only arguments that are not None are sent.

        Args:
            chat_id: Chat assistant ID.
            name: New name.
            avatar: New avatar.
            dataset_ids: New dataset IDs.
            llm: New LLM settings.
            prompt: New prompt settings.
        """
        data = self._compact({
            "name": name,
            "avatar": avatar,
            "dataset_ids": dataset_ids,
            "llm": llm,
            "prompt": prompt,
        })
        return self._request('PUT', f"/api/v1/chats/{chat_id}", json=data)

    def delete_chat_assistants(self, ids: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Delete chat assistants.

        Args:
            ids: Chat assistant IDs to delete. When None or empty, an empty JSON
                body is sent. NOTE(review): confirm whether that deletes all.
        """
        data = {"ids": ids} if ids else {}
        return self._request('DELETE', "/api/v1/chats", json=data)

    def list_chat_assistants(self, page: int = 1, page_size: int = 30,
                             orderby: str = "create_time", desc: bool = True,
                             name: Optional[str] = None,
                             chat_id: Optional[str] = None) -> Dict[str, Any]:
        """
        List chat assistants.

        Args:
            page: Page number.
            page_size: Items per page.
            orderby: Sort field.
            desc: Sort descending.
            name: Filter by name.
            chat_id: Filter by ID.
        """
        params = self._list_params(page, page_size, orderby, desc, name=name, id=chat_id)
        return self._request('GET', "/api/v1/chats", params=params)

    # ====================
    # Session management
    # ====================

    def create_session_with_chat(self, chat_id: str, name: str,
                                 user_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Create a session with a chat assistant.

        Args:
            chat_id: Chat assistant ID.
            name: Session name.
            user_id: Optional caller-defined user ID.
        """
        data = {"name": name}
        data.update(self._compact({"user_id": user_id}, drop_falsy=True))
        return self._request('POST', f"/api/v1/chats/{chat_id}/sessions", json=data)

    def update_chat_session(self, chat_id: str, session_id: str, name: Optional[str] = None,
                            user_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Update a chat session. Only arguments that are not None are sent.

        Args:
            chat_id: Chat assistant ID.
            session_id: Session ID.
            name: New name.
            user_id: New user ID.
        """
        data = self._compact({"name": name, "user_id": user_id})
        endpoint = f"/api/v1/chats/{chat_id}/sessions/{session_id}"
        return self._request('PUT', endpoint, json=data)

    def list_chat_sessions(self, chat_id: str, page: int = 1, page_size: int = 30,
                           orderby: str = "create_time", desc: bool = True,
                           name: Optional[str] = None, session_id: Optional[str] = None,
                           user_id: Optional[str] = None) -> Dict[str, Any]:
        """
        List the sessions of a chat assistant.

        Args:
            chat_id: Chat assistant ID.
            page: Page number.
            page_size: Items per page.
            orderby: Sort field.
            desc: Sort descending.
            name: Filter by name.
            session_id: Filter by session ID.
            user_id: Filter by user ID.
        """
        params = self._list_params(page, page_size, orderby, desc,
                                   name=name, id=session_id, user_id=user_id)
        return self._request('GET', f"/api/v1/chats/{chat_id}/sessions", params=params)

    def delete_chat_sessions(self, chat_id: str, ids: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Delete chat sessions.

        Args:
            chat_id: Chat assistant ID.
            ids: Session IDs to delete. When None or empty, an empty JSON body is
                sent. NOTE(review): confirm whether that deletes all sessions.
        """
        data = {"ids": ids} if ids else {}
        return self._request('DELETE', f"/api/v1/chats/{chat_id}/sessions", json=data)

    def converse_with_chat_assistant(self, chat_id: str, question: str, stream: bool = True,
                                     session_id: Optional[str] = None,
                                     user_id: Optional[str] = None) -> Union[Dict[str, Any], Generator[Dict[str, Any], None, None]]:
        """
        Ask a chat assistant a question.

        Args:
            chat_id: Chat assistant ID.
            question: Question text.
            stream: If True (default), return a generator of decoded SSE events.
            session_id: Session ID (omitted when empty).
            user_id: User ID (omitted when empty).
        """
        data = {"question": question, "stream": stream}
        data.update(self._compact({"session_id": session_id, "user_id": user_id}, drop_falsy=True))
        return self._post_json(f"/api/v1/chats/{chat_id}/completions", data, stream)

    # ====================
    # Agent sessions
    # ====================

    def create_session_with_agent(self, agent_id: str, user_id: Optional[str] = None,
                                  file_data: Optional[Dict[str, Any]] = None,
                                  **kwargs) -> Dict[str, Any]:
        """
        Create a session with an agent.

        Args:
            agent_id: Agent ID.
            user_id: User ID, sent as a query parameter (omitted when empty).
            file_data: Mapping of Begin-component field name to local file path,
                for Begin components that take file inputs. When given, the
                request is multipart; paths that do not exist are skipped.
            **kwargs: Other Begin-component inputs. Sent as the JSON body, or as
                multipart form fields alongside the files when ``file_data`` is given.
        """
        endpoint = f"/api/v1/agents/{agent_id}/sessions"
        params = self._compact({"user_id": user_id}, drop_falsy=True)

        if not file_data:
            return self._request('POST', endpoint, json=kwargs, params=params)

        with ExitStack() as stack:
            # ExitStack closes every opened file even if a later open() or the
            # request itself raises.
            files = {key: stack.enter_context(open(path, 'rb'))
                     for key, path in file_data.items() if os.path.exists(path)}
            return self._request('POST', endpoint, headers=self._auth_headers(),
                                 files=files, data=kwargs, params=params)

    def converse_with_agent(self, agent_id: str, question: str, stream: bool = True,
                            session_id: Optional[str] = None, user_id: Optional[str] = None,
                            sync_dsl: bool = False, **kwargs) -> Union[Dict[str, Any], Generator[Dict[str, Any], None, None]]:
        """
        Ask an agent a question.

        Args:
            agent_id: Agent ID.
            question: Question text.
            stream: If True (default), return a generator of decoded SSE events.
            session_id: Session ID (omitted when empty).
            user_id: User ID (omitted when empty).
            sync_dsl: Whether to sync DSL changes into the existing session.
            **kwargs: Extra Begin-component inputs, merged last into the body
                (they override same-named keys).
        """
        data = {"question": question, "stream": stream, "sync_dsl": sync_dsl}
        data.update(self._compact({"session_id": session_id, "user_id": user_id}, drop_falsy=True))
        data.update(kwargs)
        return self._post_json(f"/api/v1/agents/{agent_id}/completions", data, stream)

    def list_agent_sessions(self, agent_id: str, page: int = 1, page_size: int = 30,
                            orderby: str = "create_time", desc: bool = True,
                            session_id: Optional[str] = None, user_id: Optional[str] = None,
                            dsl: bool = True) -> Dict[str, Any]:
        """
        List an agent's sessions.

        Args:
            agent_id: Agent ID.
            page: Page number.
            page_size: Items per page.
            orderby: Sort field.
            desc: Sort descending.
            session_id: Filter by session ID.
            user_id: Filter by user ID.
            dsl: Whether responses include the "dsl" field.
        """
        params = self._list_params(page, page_size, orderby, desc)
        params["dsl"] = dsl  # always sent, including when False
        params.update(self._compact({"id": session_id, "user_id": user_id}, drop_falsy=True))
        return self._request('GET', f"/api/v1/agents/{agent_id}/sessions", params=params)

    def delete_agent_sessions(self, agent_id: str, ids: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Delete agent sessions.

        Args:
            agent_id: Agent ID.
            ids: Session IDs to delete. When None or empty, an empty JSON body is
                sent. NOTE(review): confirm whether that deletes all sessions.
        """
        data = {"ids": ids} if ids else {}
        return self._request('DELETE', f"/api/v1/agents/{agent_id}/sessions", json=data)

    def get_related_questions(self, question: str, login_token: str) -> Dict[str, Any]:
        """
        Generate questions related to ``question``.

        This endpoint authenticates with a login token, not the API key.

        Args:
            question: Original question.
            login_token: Login token, sent as the Bearer token.
        """
        headers = {
            'Authorization': f'Bearer {login_token}',
            'Content-Type': 'application/json',
        }
        return self._request('POST', "/v1/sessions/related_questions",
                             headers=headers, json={"question": question})

    # ====================
    # Agent management
    # ====================

    def list_agents(self, page: int = 1, page_size: int = 30, orderby: str = "create_time",
                    desc: bool = True, name: Optional[str] = None,
                    agent_id: Optional[str] = None) -> Dict[str, Any]:
        """
        List agents.

        Args:
            page: Page number.
            page_size: Items per page.
            orderby: Sort field.
            desc: Sort descending.
            name: Filter by name.
            agent_id: Filter by ID.
        """
        params = self._list_params(page, page_size, orderby, desc, name=name, id=agent_id)
        return self._request('GET', "/api/v1/agents", params=params)

    def create_agent(self, title: str, description: Optional[str] = None,
                     dsl: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Create an agent. Optional arguments that are None are not sent.

        Args:
            title: Agent title.
            description: Agent description.
            dsl: Canvas DSL object.
        """
        data = {"title": title}
        data.update(self._compact({"description": description, "dsl": dsl}))
        return self._request('POST', "/api/v1/agents", json=data)

    def update_agent(self, agent_id: str, title: Optional[str] = None,
                     description: Optional[str] = None,
                     dsl: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Update an agent. Only arguments that are not None are sent.

        Args:
            agent_id: Agent ID.
            title: New title.
            description: New description.
            dsl: New DSL.
        """
        data = self._compact({"title": title, "description": description, "dsl": dsl})
        return self._request('PUT', f"/api/v1/agents/{agent_id}", json=data)

    def delete_agent(self, agent_id: str) -> Dict[str, Any]:
        """
        Delete an agent.

        Args:
            agent_id: Agent ID.
        """
        return self._request('DELETE', f"/api/v1/agents/{agent_id}")
# ====================
# Usage examples
# ====================


def example_usage():
    """
    End-to-end RAGFlow SDK walkthrough.

    Steps: delete a dataset, create a dataset, upload and parse documents, create
    a chat assistant and a session, stream an answer, retrieve chunks, and list
    datasets. RAGFlow API errors and other exceptions are printed, not raised.

    Connection settings come from the environment: RAGFLOW_API_KEY (required;
    keep API keys out of source code) and RAGFLOW_BASE_URL (defaults to
    "http://10.0.0.202:82").
    """
    client = RAGFlowClient(
        base_url=os.environ.get("RAGFLOW_BASE_URL", "http://10.0.0.202:82"),
        api_key=os.environ["RAGFLOW_API_KEY"],
    )

    try:
        # Best-effort deletion of a specific, hard-coded dataset ID. It usually
        # does not exist on other servers, so a failure here must not abort the
        # rest of the walkthrough.
        try:
            client.delete_datasets(ids=["afe1387883bb11f0a0fd0242ac170006"])
            print("删除数据集成功")
        except RAGFlowError as e:
            print(f"RAGFlow API错误: {e}")

        # 1. Create a dataset with default settings.
        dataset = client.create_dataset(
            name="我的数据集",
            description="这是一个测试数据集",
            chunk_method="naive"
        )
        dataset_id = dataset['data']['id']
        print(f"创建数据集成功: {dataset_id}")

        # 2. Upload documents (paths that do not exist are skipped by the client).
        documents = client.upload_documents(
            dataset_id=dataset_id,
            file_paths=[
                "/home/admin-root/haotian/康达瑞贝斯机器人后台/ruoyi-fastapi-backend/requirements.txt",
                "/home/admin-root/haotian/康达瑞贝斯机器人后台/ruoyi-fastapi-backend/requirements-pg.txt"
            ]
        )
        print("文档上传成功")

        # 3. Start parsing the uploaded documents.
        document_ids = [doc['id'] for doc in documents['data']]
        client.parse_documents(dataset_id, document_ids)
        print("开始解析文档")

        # Fixed pause only; this does not wait for parsing to complete.
        import time
        time.sleep(5)

        # 4. Create a chat assistant linked to the dataset.
        chat_assistant = client.create_chat_assistant(
            name="我的AI助手",
            dataset_ids=[dataset_id]
        )
        chat_id = chat_assistant['data']['id']
        print(f"创建聊天助手成功: {chat_id}")

        # 5. Create a session.
        session = client.create_session_with_chat(
            chat_id=chat_id,
            name="测试会话"
        )
        session_id = session['data']['id']
        print(f"创建会话成功: {session_id}")

        # 6. Ask a question and print the streamed answer.
        responses = client.converse_with_chat_assistant(
            chat_id=chat_id,
            question="你好,请介绍一下自己",
            stream=True,
            session_id=session_id
        )

        print("AI回复:")
        for response in responses:
            if response.get('data') and isinstance(response['data'], dict):
                answer = response['data'].get('answer', '')
                if answer:
                    print(answer, end='', flush=True)
        print()

        # 7. Retrieve related chunks.
        chunks = client.retrieve_chunks(
            question="RAGFlow的优势是什么?",
            dataset_ids=[dataset_id],
            top_k=5,
            highlight=True
        )
        print(f"检索到 {chunks['data']['total']} 个相关文档块")

        # 8. List datasets.
        datasets = client.list_datasets(page=1, page_size=10)
        print(f"当前有 {len(datasets['data'])} 个数据集")

    except RAGFlowError as e:
        print(f"RAGFlow API错误: {e}")
    except Exception as e:
        print(f"其他错误: {e}")
def example_usage_1():
    """
    List datasets and print the document names in each one.

    Only the first page of each listing is fetched (30 items per page by
    default). Connection settings come from RAGFLOW_API_KEY (required; keep API
    keys out of source code) and RAGFLOW_BASE_URL (defaults to
    "http://10.0.0.202:82").
    """
    client = RAGFlowClient(
        base_url=os.environ.get("RAGFLOW_BASE_URL", "http://10.0.0.202:82"),
        api_key=os.environ["RAGFLOW_API_KEY"],
    )

    # 1. Fetch the dataset list.
    datasets = client.list_datasets()['data']
    print(f"获取数据集列表成功,共有 {len(datasets)} 个数据集")
    print("数据集为id:\n", [result["id"] for result in datasets])

    # 2. Print the document names of each dataset.
    for result in datasets:
        print(f"数据集 {result['id']} 的文档列表为:")
        results_doc = client.list_documents(dataset_id=result['id'])
        print([t["name"] for t in results_doc["data"]["docs"]])
if __name__ == "__main__":
    # Run the examples. example_usage() creates, uploads to and deletes
    # server-side data, so it is left disabled; uncomment it to run it.
    # example_usage()

    example_usage_1()