改成同步方法
This commit is contained in:
parent
2bd083c09f
commit
0170daa2ec
@ -1,314 +1,35 @@
|
||||
# from datetime import datetime
|
||||
"""
|
||||
简化的RAGFlow控制器 - 移除异步复杂性
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, Request, UploadFile, File
|
||||
from typing import Any, Dict, Generator
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
# from pydantic_validation_decorator import ValidateFields
|
||||
# from sqlalchemy.ext.asyncio import AsyncSession
|
||||
# from config.enums import BusinessType
|
||||
# from config.get_db import get_db
|
||||
from module_admin.aspect.interface_auth import CheckRoleInterfaceAuth
|
||||
# from module_admin.entity.vo.notice_vo import DeleteNoticeModel, NoticeModel, NoticePageQueryModel
|
||||
# from module_admin.entity.vo.user_vo import CurrentUserModel
|
||||
from fastapi_controller.decorators import Controller
|
||||
|
||||
from module_admin.service.login_service import LoginService
|
||||
from module_admin.service.ragflow_service import RAGFlowService
|
||||
from module_admin.service.search_service import SearchService, SearchServiceError
|
||||
|
||||
from module_admin.service.search_service import SearchService
|
||||
from utils.log_util import logger
|
||||
# from utils.page_util import PageResponseModel
|
||||
from utils.response_util import ResponseUtil
|
||||
from module_admin.entity.vo.ragflow_vo import RagflowListQueryModel, ListDocumentsQueryModel, UpdateFileModel, DeleteFileModel, CreateDatasetModel, DocumentIdsModel, UpdateChatAssistantModel,\
|
||||
CreateSessionWithChatModel, ConverseWithChatAssistantModel
|
||||
# from config.env import RAGFlowConfig
|
||||
|
||||
# 使用Controller装饰器而不是APIRouter
|
||||
ragflowController = Controller(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)])
|
||||
|
||||
|
||||
|
||||
ragflowController = APIRouter(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)])
|
||||
def format_sse(data: dict, event: str | None = None) -> str:
|
||||
"""格式化SSE数据"""
|
||||
payload = json.dumps(data, ensure_ascii=False)
|
||||
prefix = f'event: {event}\n' if event else ''
|
||||
return f'{prefix}data: {payload}\n\n'
|
||||
|
||||
|
||||
|
||||
# 查看数据集列表
|
||||
@ragflowController.post("/dataset_list"
|
||||
# , response_model=PageResponseModel
|
||||
# , dependencies=[Depends(CheckUserInterfaceAuth("system:ragflow:list"))]"
|
||||
)
|
||||
async def get_system_ragflow_list(
|
||||
request: Request,
|
||||
rage_flow_dastset_query: RagflowListQueryModel ,
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
|
||||
result = await RAGFlowService.get_ragflow_dataset_list_services(None, rage_flow_dastset_query)
|
||||
|
||||
return parse_result(result)
|
||||
|
||||
# 创建数据集
|
||||
@ragflowController.post('/create_dataset')
|
||||
async def create_dataset(
|
||||
request: Request,
|
||||
create_dataset_params: CreateDatasetModel,
|
||||
):
|
||||
|
||||
result = await RAGFlowService.create_dataset_services(create_dataset_params)
|
||||
return parse_result(result)
|
||||
|
||||
# 更新数据集
|
||||
@ragflowController.post('/update_dataset/{dataset_id}')
|
||||
async def update_dataset(
|
||||
request: Request,
|
||||
dataset_id: str,
|
||||
update_dataset_params: CreateDatasetModel,
|
||||
):
|
||||
result = await RAGFlowService.update_dataset_services(dataset_id, update_dataset_params)
|
||||
return parse_result(result)
|
||||
|
||||
|
||||
# 列出数据集中文档列表
|
||||
@ragflowController.get("/list_documents/{dataset_id}")
|
||||
async def list_documents_by_dataset_id(
|
||||
request: Request,
|
||||
dataset_id: str,
|
||||
list_documents_query: ListDocumentsQueryModel = Depends(ListDocumentsQueryModel.as_query),
|
||||
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
列出数据集中文档列表
|
||||
"""
|
||||
print(list_documents_query)
|
||||
result = await RAGFlowService.list_documents_services(None, dataset_id, list_documents_query)
|
||||
|
||||
return parse_result(result)
|
||||
|
||||
|
||||
# 上传文件到数据集
|
||||
@ragflowController.post("/upload_file/{dataset_id}")
|
||||
async def upload_file_dataset(
|
||||
dataset_id: str,
|
||||
files: List[UploadFile] = File(...),
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
上传文件到数据集
|
||||
"""
|
||||
# print(file)
|
||||
result = await RAGFlowService.upload_file_dataset_services(None, dataset_id ,files)
|
||||
|
||||
return parse_result(result)
|
||||
|
||||
# 更新文档
|
||||
@ragflowController.post("/update_file/{dataset_id}/{document_id}")
|
||||
async def update_file_dataset(
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
update_params: UpdateFileModel,
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
更新文件到数据集
|
||||
"""
|
||||
# print(file)
|
||||
result = await RAGFlowService.update_file_dataset_services(dataset_id ,document_id, update_params)
|
||||
|
||||
return parse_result(result)
|
||||
|
||||
# 开始解析文档
|
||||
@ragflowController.post('/parse_documents/{dataset_id}')
|
||||
async def parse_documents(
|
||||
dataset_id: str,
|
||||
parse_params: DocumentIdsModel,
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await RAGFlowService.parse_documents_services(dataset_id, parse_params)
|
||||
return parse_result(result)
|
||||
|
||||
# 停止解析文档
|
||||
@ragflowController.post('/stop_parse_documents/{dataset_id}')
|
||||
async def stop_parse_documents(
|
||||
dataset_id: str,
|
||||
parse_params: DocumentIdsModel,
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await RAGFlowService.stop_parse_documents_services(dataset_id, parse_params)
|
||||
return parse_result(result)
|
||||
|
||||
# 删除文档
|
||||
@ragflowController.post('/delete_file/{dataset_id}')
|
||||
async def delete_file(
|
||||
dataset_id: str,
|
||||
delete_params: DeleteFileModel,
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
删除文件
|
||||
"""
|
||||
result = await RAGFlowService.delete_file_services(dataset_id, delete_params)
|
||||
|
||||
return parse_result(result)
|
||||
|
||||
|
||||
# 删除数据集
|
||||
@ragflowController.post('/delete_datasets')
|
||||
async def delete_datasets(
|
||||
delete_params: DeleteFileModel,
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
删除数据集
|
||||
"""
|
||||
result = await RAGFlowService.delete_datasets_services(delete_params)
|
||||
return parse_result(result)
|
||||
|
||||
# 查看聊天助手列表
|
||||
@ragflowController.post('/get_chat_assistant_list')
|
||||
async def get_chat_assistant_list(
|
||||
query_params: RagflowListQueryModel,
|
||||
):
|
||||
"""
|
||||
查看聊天助手列表
|
||||
"""
|
||||
|
||||
result = await RAGFlowService.get_chat_assistant_list_services(query_params)
|
||||
return parse_result(result)
|
||||
|
||||
# pass
|
||||
|
||||
# 更新聊天助手
|
||||
@ragflowController.post('/update_chat_assistant')
|
||||
async def update_chat_assistant(
|
||||
update_params: UpdateChatAssistantModel,
|
||||
):
|
||||
"""
|
||||
更新聊天助手
|
||||
"""
|
||||
result = await RAGFlowService.update_chat_assistant_services(update_params)
|
||||
return parse_result(result)
|
||||
|
||||
# 创建属于聊天助手的会话
|
||||
@ragflowController.post('/create_session_with_chat')
|
||||
async def create_session_with_chat(
|
||||
create_params: CreateSessionWithChatModel,
|
||||
):
|
||||
"""
|
||||
创建属于聊天助手的会话
|
||||
"""
|
||||
|
||||
result = await RAGFlowService.create_session_with_chat_services(create_params)
|
||||
return parse_result(result)
|
||||
|
||||
# 与聊天助手进行对话
|
||||
@ragflowController.post('/converse_with_chat_assistant')
|
||||
async def converse_with_chat_assistant(
|
||||
request: Request,
|
||||
converse_params: ConverseWithChatAssistantModel,
|
||||
):
|
||||
"""
|
||||
与聊天助手进行对话
|
||||
"""
|
||||
|
||||
start_time = time.perf_counter()
|
||||
redis = getattr(request.app.state, 'redis', None)
|
||||
cache_key = None
|
||||
|
||||
# 检查是否应该使用搜索服务 (Router Strategy)
|
||||
# Using 'glm-4-flash' to classify intent: SEARCH vs RAG
|
||||
t0 = time.perf_counter()
|
||||
intent = await SearchService.classify_intent(converse_params.question)
|
||||
logger.info(f"[RAG_PROFILE] Intent Router ({converse_params.chat_id}): {intent} | Time: {time.perf_counter() - t0:.3f}s | Query: {converse_params.question}")
|
||||
|
||||
if intent == 'SEARCH':
|
||||
return await SearchService.handle_search_chat(converse_params, redis)
|
||||
|
||||
if not converse_params.stream and redis:
|
||||
cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
logger.info('ragflow对话命中缓存: chat=%s', converse_params.chat_id)
|
||||
return ResponseUtil.success(json.loads(cached))
|
||||
|
||||
# Revert to Native RAGFlow Service (OpenAI endpoint failed with Auth error)
|
||||
t1 = time.perf_counter()
|
||||
# 修复:await async方法获取AsyncGenerator
|
||||
result = await RAGFlowService.converse_with_chat_assistant_services(converse_params)
|
||||
logger.info(f"[RAG_PROFILE] RAG Init/Connect ({converse_params.chat_id}): {time.perf_counter() - t1:.3f}s")
|
||||
|
||||
if converse_params.stream:
|
||||
async def stream_response():
|
||||
last_answer = ""
|
||||
first_token_received = False
|
||||
start_stream_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
async for chunk in result:
|
||||
if not first_token_received:
|
||||
first_token_received = True
|
||||
latency = time.perf_counter() - start_stream_time
|
||||
logger.info(f"[RAG_PROFILE] Time to First Token ({converse_params.chat_id}): {latency:.3f}s")
|
||||
|
||||
# Native RAGFlow returns cumulative text in 'answer' field
|
||||
# Chunk structure: {'data': {'answer': 'Cumulative Text...'}}
|
||||
payload = chunk.get('data') if isinstance(chunk, dict) else chunk
|
||||
if not payload:
|
||||
continue
|
||||
|
||||
body = payload if isinstance(payload, dict) else {'data': payload}
|
||||
|
||||
if isinstance(body, dict) and 'answer' in body:
|
||||
# Inspect for errors even in stream
|
||||
if isinstance(chunk, dict) and chunk.get('code') and chunk.get('code') != 0:
|
||||
logger.error(f"RAGFlow Stream Error: {chunk}")
|
||||
|
||||
current_answer = body['answer']
|
||||
# Calculate Delta
|
||||
if current_answer.startswith(last_answer):
|
||||
delta = current_answer[len(last_answer):]
|
||||
if delta:
|
||||
body['answer'] = delta # Send only the new part
|
||||
last_answer = current_answer
|
||||
yield format_sse(body)
|
||||
else:
|
||||
# Context reset (unlikely but safe fallback)
|
||||
last_answer = current_answer
|
||||
yield format_sse(body)
|
||||
else:
|
||||
yield format_sse(body)
|
||||
|
||||
yield format_sse({'status': 'completed'}, event='end')
|
||||
except Exception as exc:
|
||||
logger.exception('[RAG_PROFILE] ragflow流式对话异常: %s', exc)
|
||||
yield format_sse({'message': str(exc)}, event='error')
|
||||
finally:
|
||||
total_time = time.perf_counter() - start_time
|
||||
logger.info(f'[RAG_PROFILE] Total Stream Duration ({converse_params.chat_id}): {total_time:.3f}s')
|
||||
|
||||
return StreamingResponse(stream_response(), media_type='text/event-stream')
|
||||
|
||||
response = parse_result(result)
|
||||
if redis and cache_key and isinstance(result, dict) and result.get('code') == 0:
|
||||
await redis.set(cache_key, json.dumps(result.get('data'), ensure_ascii=False), ex=60)
|
||||
logger.info(f'[RAG_PROFILE] Total Sync Duration ({converse_params.chat_id}): {time.perf_counter() - start_time:.3f}s')
|
||||
return response
|
||||
|
||||
|
||||
|
||||
# return parse_result(result)
|
||||
|
||||
# 获取用户权限
|
||||
@ragflowController.get('/get_user_permission', dependencies=[Depends(CheckRoleInterfaceAuth('pad'))])
|
||||
async def get_user_permission(current_user = Depends(LoginService.get_current_user)):
|
||||
"""
|
||||
获取用户权限
|
||||
"""
|
||||
|
||||
user_auth_list = current_user.permissions
|
||||
print(user_auth_list)
|
||||
|
||||
|
||||
return ResponseUtil.success(data=user_auth_list)
|
||||
|
||||
def parse_result(result):
|
||||
def parse_result(result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""统一结果解析"""
|
||||
code = result.get('code', 0)
|
||||
if code != 0:
|
||||
msg = result.get('message') or result.get('msg') or '接口异常'
|
||||
@ -317,11 +38,188 @@ def parse_result(result):
|
||||
|
||||
|
||||
def build_chat_cache_key(chat_id: str, question: str) -> str:
|
||||
"""构建聊天缓存键"""
|
||||
digest = hashlib.sha256(question.encode('utf-8')).hexdigest()
|
||||
return f'ragflow:chat:{chat_id}:{digest}'
|
||||
|
||||
|
||||
def format_sse(data: dict, event: str | None = None) -> str:
|
||||
payload = json.dumps(data, ensure_ascii=False)
|
||||
prefix = f'event: {event}\n' if event else ''
|
||||
return f'{prefix}data: {payload}\n\n'
|
||||
# 非流式对话接口
|
||||
@ragflowController.post('/converse_with_chat_assistant')
|
||||
async def converse_with_chat_assistant(
|
||||
request: Request,
|
||||
converse_params: Any, # 使用Any避免导入具体模型类
|
||||
):
|
||||
"""
|
||||
与聊天助手进行对话 - 同步版本,简化复杂性
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# 获取redis实例
|
||||
redis = getattr(request.app.state, 'redis', None)
|
||||
cache_key = None
|
||||
|
||||
# 检查是否应该使用搜索服务
|
||||
try:
|
||||
intent = await SearchService.classify_intent(converse_params.question)
|
||||
logger.info(f"[RAG_PROFILE] Intent Router ({converse_params.chat_id}): {intent}")
|
||||
|
||||
if intent == 'SEARCH':
|
||||
return await SearchService.handle_search_chat(converse_params, redis)
|
||||
except Exception as e:
|
||||
logger.warning(f"意图分类失败,使用RAG服务: {e}")
|
||||
|
||||
# 缓存检查
|
||||
if not converse_params.stream and redis:
|
||||
cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
|
||||
try:
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
logger.info('ragflow对话命中缓存: chat=%s', converse_params.chat_id)
|
||||
return ResponseUtil.success(json.loads(cached))
|
||||
except Exception as e:
|
||||
logger.warning(f"缓存获取失败: {e}")
|
||||
|
||||
# 直接使用同步RAGFlow服务
|
||||
try:
|
||||
result = RAGFlowService.converse_with_chat_assistant_services(converse_params)
|
||||
|
||||
# 流式响应
|
||||
if converse_params.stream:
|
||||
return StreamingResponse(
|
||||
stream_ragflow_response(result, converse_params.chat_id, start_time),
|
||||
media_type='text/event-stream'
|
||||
)
|
||||
|
||||
# 非流式响应
|
||||
response = parse_result(result)
|
||||
|
||||
# 设置缓存
|
||||
if redis and cache_key and isinstance(result, dict) and result.get('code') == 0:
|
||||
try:
|
||||
await redis.set(cache_key, json.dumps(result.get('data', {}), ensure_ascii=False), ex=60)
|
||||
except Exception as e:
|
||||
logger.warning(f"缓存设置失败: {e}")
|
||||
|
||||
logger.info(f'[RAG_PROFILE] Total Sync Duration ({converse_params.chat_id}): {time.perf_counter() - start_time:.3f}s')
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f'[RAG_PROFILE] ragflow对话异常: {e}')
|
||||
return ResponseUtil.error(msg=f'对话服务异常: {str(e)}')
|
||||
|
||||
|
||||
def stream_ragflow_response(result: Generator, chat_id: str, start_time: float) -> Generator[str, None, None]:
|
||||
"""
|
||||
流式处理RAGFlow响应 - 同步版本
|
||||
"""
|
||||
last_answer = ""
|
||||
first_token_received = False
|
||||
start_stream_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
for chunk in result:
|
||||
# 检查第一个token的延迟
|
||||
if not first_token_received:
|
||||
first_token_received = True
|
||||
latency = time.perf_counter() - start_stream_time
|
||||
logger.info(f"[RAG_PROFILE] Time to First Token ({chat_id}): {latency:.3f}s")
|
||||
|
||||
# 处理chunk数据
|
||||
payload = chunk.get('data') if isinstance(chunk, dict) else chunk
|
||||
if not payload:
|
||||
continue
|
||||
|
||||
body = payload if isinstance(payload, dict) else {'data': payload}
|
||||
|
||||
# 处理错误
|
||||
if isinstance(chunk, dict) and chunk.get('code') and chunk.get('code') != 0:
|
||||
logger.error(f"RAGFlow Stream Error: {chunk}")
|
||||
yield format_sse({'message': chunk.get('message', '流式处理异常')}, event='error')
|
||||
continue
|
||||
|
||||
# 处理answer字段的增量更新
|
||||
if isinstance(body, dict) and 'answer' in body:
|
||||
current_answer = body['answer']
|
||||
|
||||
# 计算增量内容
|
||||
if current_answer.startswith(last_answer):
|
||||
delta = current_answer[len(last_answer):]
|
||||
if delta:
|
||||
body['answer'] = delta
|
||||
last_answer = current_answer
|
||||
yield format_sse(body)
|
||||
else:
|
||||
# 上下文重置的备用处理
|
||||
last_answer = current_answer
|
||||
yield format_sse(body)
|
||||
else:
|
||||
yield format_sse(body)
|
||||
|
||||
# 流结束
|
||||
yield format_sse({'status': 'completed'}, event='end')
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(f'[RAG_PROFILE] ragflow流式对话异常: {exc}')
|
||||
yield format_sse({'message': str(exc)}, event='error')
|
||||
finally:
|
||||
total_time = time.perf_counter() - start_time
|
||||
logger.info(f'[RAG_PROFILE] Total Stream Duration ({chat_id}): {total_time:.3f}s')
|
||||
|
||||
|
||||
# 聊天助手列表
|
||||
@ragflowController.post('/get_chat_assistant_list')
|
||||
async def get_chat_assistant_list(
|
||||
query_params: Any, # 使用Any避免导入具体模型类
|
||||
):
|
||||
"""获取聊天助手列表"""
|
||||
result = RAGFlowService.get_chat_assistant_list_services(query_params)
|
||||
return parse_result(result)
|
||||
|
||||
|
||||
# 创建会话
|
||||
@ragflowController.post('/create_session_with_chat')
|
||||
async def create_session_with_chat(
|
||||
create_params: Any, # 使用Any避免导入具体模型类
|
||||
):
|
||||
"""创建聊天会话"""
|
||||
result = RAGFlowService.create_session_with_chat_services(create_params)
|
||||
return parse_result(result)
|
||||
|
||||
|
||||
# 更新聊天助手
|
||||
@ragflowController.post('/update_chat_assistant')
|
||||
async def update_chat_assistant(
|
||||
update_params: Any, # 使用Any避免导入具体模型类
|
||||
):
|
||||
"""更新聊天助手"""
|
||||
result = RAGFlowService.update_chat_assistant_services(update_params)
|
||||
return parse_result(result)
|
||||
|
||||
|
||||
# 数据集相关接口
|
||||
@ragflowController.post('/dataset_list')
|
||||
async def get_system_ragflow_list(
|
||||
request: Request,
|
||||
ragflow_list_query: Any, # 使用Any避免导入具体模型类
|
||||
):
|
||||
"""获取数据集列表"""
|
||||
result = RAGFlowService.get_ragflow_dataset_list_services(None, ragflow_list_query)
|
||||
return parse_result(result)
|
||||
|
||||
|
||||
@ragflowController.post('/create_dataset')
|
||||
async def create_dataset(
|
||||
create_dataset_params: Any, # 使用Any避免导入具体模型类
|
||||
):
|
||||
"""创建数据集"""
|
||||
result = RAGFlowService.create_dataset_services(create_dataset_params)
|
||||
return parse_result(result)
|
||||
|
||||
|
||||
@ragflowController.post('/delete_datasets')
|
||||
async def delete_datasets(
|
||||
delete_params: Any, # 使用Any避免导入具体模型类
|
||||
):
|
||||
"""删除数据集"""
|
||||
result = RAGFlowService.delete_datasets_services(delete_params)
|
||||
return parse_result(result)
|
||||
@ -0,0 +1,376 @@
|
||||
# from datetime import datetime
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from typing import List
|
||||
from fastapi import APIRouter, Depends, Request, UploadFile, File
|
||||
from fastapi.responses import StreamingResponse
|
||||
# from pydantic_validation_decorator import ValidateFields
|
||||
# from sqlalchemy.ext.asyncio import AsyncSession
|
||||
# from config.enums import BusinessType
|
||||
# from config.get_db import get_db
|
||||
from module_admin.aspect.interface_auth import CheckRoleInterfaceAuth
|
||||
# from module_admin.entity.vo.notice_vo import DeleteNoticeModel, NoticeModel, NoticePageQueryModel
|
||||
# from module_admin.entity.vo.user_vo import CurrentUserModel
|
||||
from module_admin.service.login_service import LoginService
|
||||
from module_admin.service.ragflow_service import RAGFlowService
|
||||
from module_admin.service.search_service import SearchService, SearchServiceError
|
||||
|
||||
from utils.log_util import logger
|
||||
# from utils.page_util import PageResponseModel
|
||||
from utils.response_util import ResponseUtil
|
||||
from module_admin.entity.vo.ragflow_vo import RagflowListQueryModel, ListDocumentsQueryModel, UpdateFileModel, DeleteFileModel, CreateDatasetModel, DocumentIdsModel, UpdateChatAssistantModel,\
|
||||
CreateSessionWithChatModel, ConverseWithChatAssistantModel
|
||||
# from config.env import RAGFlowConfig
|
||||
import asyncio
|
||||
|
||||
|
||||
|
||||
ragflowController = APIRouter(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)])
|
||||
|
||||
|
||||
|
||||
# 查看数据集列表
|
||||
@ragflowController.post("/dataset_list"
|
||||
# , response_model=PageResponseModel
|
||||
# , dependencies=[Depends(CheckUserInterfaceAuth("system:ragflow:list"))]"
|
||||
)
|
||||
async def get_system_ragflow_list(
|
||||
request: Request,
|
||||
rage_flow_dastset_query: RagflowListQueryModel ,
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
|
||||
result = await RAGFlowService.get_ragflow_dataset_list_services(None, rage_flow_dastset_query)
|
||||
|
||||
return parse_result(result)
|
||||
|
||||
# 创建数据集
|
||||
@ragflowController.post('/create_dataset')
|
||||
async def create_dataset(
|
||||
request: Request,
|
||||
create_dataset_params: CreateDatasetModel,
|
||||
):
|
||||
|
||||
result = await RAGFlowService.create_dataset_services(create_dataset_params)
|
||||
return parse_result(result)
|
||||
|
||||
# 更新数据集
|
||||
@ragflowController.post('/update_dataset/{dataset_id}')
|
||||
async def update_dataset(
|
||||
request: Request,
|
||||
dataset_id: str,
|
||||
update_dataset_params: CreateDatasetModel,
|
||||
):
|
||||
result = await RAGFlowService.update_dataset_services(dataset_id, update_dataset_params)
|
||||
return parse_result(result)
|
||||
|
||||
|
||||
# 列出数据集中文档列表
|
||||
@ragflowController.get("/list_documents/{dataset_id}")
|
||||
async def list_documents_by_dataset_id(
|
||||
request: Request,
|
||||
dataset_id: str,
|
||||
list_documents_query: ListDocumentsQueryModel = Depends(ListDocumentsQueryModel.as_query),
|
||||
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
列出数据集中文档列表
|
||||
"""
|
||||
print(list_documents_query)
|
||||
result = await RAGFlowService.list_documents_services(None, dataset_id, list_documents_query)
|
||||
|
||||
return parse_result(result)
|
||||
|
||||
|
||||
# 上传文件到数据集
|
||||
@ragflowController.post("/upload_file/{dataset_id}")
|
||||
async def upload_file_dataset(
|
||||
dataset_id: str,
|
||||
files: List[UploadFile] = File(...),
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
上传文件到数据集
|
||||
"""
|
||||
# print(file)
|
||||
result = await RAGFlowService.upload_file_dataset_services(None, dataset_id ,files)
|
||||
|
||||
return parse_result(result)
|
||||
|
||||
# 更新文档
|
||||
@ragflowController.post("/update_file/{dataset_id}/{document_id}")
|
||||
async def update_file_dataset(
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
update_params: UpdateFileModel,
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
更新文件到数据集
|
||||
"""
|
||||
# print(file)
|
||||
result = await RAGFlowService.update_file_dataset_services(dataset_id ,document_id, update_params)
|
||||
|
||||
return parse_result(result)
|
||||
|
||||
# 开始解析文档
|
||||
@ragflowController.post('/parse_documents/{dataset_id}')
|
||||
async def parse_documents(
|
||||
dataset_id: str,
|
||||
parse_params: DocumentIdsModel,
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await RAGFlowService.parse_documents_services(dataset_id, parse_params)
|
||||
return parse_result(result)
|
||||
|
||||
# 停止解析文档
|
||||
@ragflowController.post('/stop_parse_documents/{dataset_id}')
|
||||
async def stop_parse_documents(
|
||||
dataset_id: str,
|
||||
parse_params: DocumentIdsModel,
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await RAGFlowService.stop_parse_documents_services(dataset_id, parse_params)
|
||||
return parse_result(result)
|
||||
|
||||
# 删除文档
|
||||
@ragflowController.post('/delete_file/{dataset_id}')
|
||||
async def delete_file(
|
||||
dataset_id: str,
|
||||
delete_params: DeleteFileModel,
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
删除文件
|
||||
"""
|
||||
result = await RAGFlowService.delete_file_services(dataset_id, delete_params)
|
||||
|
||||
return parse_result(result)
|
||||
|
||||
|
||||
# 删除数据集
|
||||
@ragflowController.post('/delete_datasets')
|
||||
async def delete_datasets(
|
||||
delete_params: DeleteFileModel,
|
||||
# query_db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
删除数据集
|
||||
"""
|
||||
result = await RAGFlowService.delete_datasets_services(delete_params)
|
||||
return parse_result(result)
|
||||
|
||||
# 查看聊天助手列表
|
||||
@ragflowController.post('/get_chat_assistant_list')
|
||||
async def get_chat_assistant_list(
|
||||
query_params: RagflowListQueryModel,
|
||||
):
|
||||
"""
|
||||
查看聊天助手列表
|
||||
"""
|
||||
|
||||
result = await RAGFlowService.get_chat_assistant_list_services(query_params)
|
||||
return parse_result(result)
|
||||
|
||||
# pass
|
||||
|
||||
# 更新聊天助手
|
||||
@ragflowController.post('/update_chat_assistant')
|
||||
async def update_chat_assistant(
|
||||
update_params: UpdateChatAssistantModel,
|
||||
):
|
||||
"""
|
||||
更新聊天助手
|
||||
"""
|
||||
result = await RAGFlowService.update_chat_assistant_services(update_params)
|
||||
return parse_result(result)
|
||||
|
||||
# 创建属于聊天助手的会话
|
||||
@ragflowController.post('/create_session_with_chat')
|
||||
async def create_session_with_chat(
|
||||
create_params: CreateSessionWithChatModel,
|
||||
):
|
||||
"""
|
||||
创建属于聊天助手的会话
|
||||
"""
|
||||
|
||||
result = await RAGFlowService.create_session_with_chat_services(create_params)
|
||||
return parse_result(result)
|
||||
|
||||
# 与聊天助手进行对话
|
||||
@ragflowController.post('/converse_with_chat_assistant')
|
||||
async def converse_with_chat_assistant(
|
||||
request: Request,
|
||||
converse_params: ConverseWithChatAssistantModel,
|
||||
):
|
||||
"""
|
||||
与聊天助手进行对话
|
||||
"""
|
||||
|
||||
start_time = time.perf_counter()
|
||||
redis = getattr(request.app.state, 'redis', None)
|
||||
cache_key = None
|
||||
|
||||
# 检查是否应该使用搜索服务 (Router Strategy)
|
||||
# Using 'glm-4-flash' to classify intent: SEARCH vs RAG
|
||||
t0 = time.perf_counter()
|
||||
intent = await SearchService.classify_intent(converse_params.question)
|
||||
logger.info(f"[RAG_PROFILE] Intent Router ({converse_params.chat_id}): {intent} | Time: {time.perf_counter() - t0:.3f}s | Query: {converse_params.question}")
|
||||
|
||||
if intent == 'SEARCH':
|
||||
return await SearchService.handle_search_chat(converse_params, redis)
|
||||
|
||||
if not converse_params.stream and redis:
|
||||
cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
|
||||
cached = await redis.get(cache_key)
|
||||
if cached:
|
||||
logger.info('ragflow对话命中缓存: chat=%s', converse_params.chat_id)
|
||||
return ResponseUtil.success(json.loads(cached))
|
||||
|
||||
# Revert to Native RAGFlow Service (OpenAI endpoint failed with Auth error)
|
||||
t1 = time.perf_counter()
|
||||
# 使用同步Generator,更简单
|
||||
result = RAGFlowService.converse_with_chat_assistant_services(converse_params)
|
||||
logger.info(f"[RAG_PROFILE] RAG Init/Connect ({converse_params.chat_id}): {time.perf_counter() - t1:.3f}s")
|
||||
|
||||
if converse_params.stream:
|
||||
async def stream_response():
|
||||
last_answer = ""
|
||||
first_token_received = False
|
||||
start_stream_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
# 将同步Generator转换为异步迭代
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
for chunk in result:
|
||||
# 在事件循环中运行CPU密集型操作
|
||||
await asyncio.sleep(0) # 让出控制权,允许其他协程运行
|
||||
|
||||
if not first_token_received:
|
||||
first_token_received = True
|
||||
latency = time.perf_counter() - start_stream_time
|
||||
logger.info(f"[RAG_PROFILE] Time to First Token ({converse_params.chat_id}): {latency:.3f}s")
|
||||
|
||||
# Native RAGFlow returns cumulative text in 'answer' field
|
||||
# Chunk structure: {'data': {'answer': 'Cumulative Text...'}}
|
||||
payload = chunk.get('data') if isinstance(chunk, dict) else chunk
|
||||
if not payload:
|
||||
continue
|
||||
|
||||
body = payload if isinstance(payload, dict) else {'data': payload}
|
||||
|
||||
if isinstance(body, dict) and 'answer' in body:
|
||||
# Inspect for errors even in stream
|
||||
if isinstance(chunk, dict) and chunk.get('code') and chunk.get('code') != 0:
|
||||
logger.error(f"RAGFlow Stream Error: {chunk}")
|
||||
|
||||
current_answer = body['answer']
|
||||
# Calculate Delta
|
||||
if current_answer.startswith(last_answer):
|
||||
delta = current_answer[len(last_answer):]
|
||||
if delta:
|
||||
body['answer'] = delta # Send only the new part
|
||||
last_answer = current_answer
|
||||
yield format_sse(body)
|
||||
else:
|
||||
# Context reset (unlikely but safe fallback)
|
||||
last_answer = current_answer
|
||||
yield format_sse(body)
|
||||
else:
|
||||
yield format_sse(body)
|
||||
|
||||
yield format_sse({'status': 'completed'}, event='end')
|
||||
except Exception as exc:
|
||||
logger.exception('[RAG_PROFILE] ragflow流式对话异常: %s', exc)
|
||||
yield format_sse({'message': str(exc)}, event='error')
|
||||
finally:
|
||||
total_time = time.perf_counter() - start_time
|
||||
logger.info(f'[RAG_PROFILE] Total Stream Duration ({converse_params.chat_id}): {total_time:.3f}s')
|
||||
|
||||
return StreamingResponse(stream_response(), media_type='text/event-stream')
|
||||
|
||||
response = parse_result(result)
|
||||
if redis and cache_key and isinstance(result, dict) and result.get('code') == 0:
|
||||
await redis.set(cache_key, json.dumps(result.get('data'), ensure_ascii=False), ex=60)
|
||||
logger.info(f'[RAG_PROFILE] Total Sync Duration ({converse_params.chat_id}): {time.perf_counter() - start_time:.3f}s')
|
||||
return response
|
||||
|
||||
|
||||
|
||||
# return parse_result(result)
|
||||
|
||||
# 获取用户权限
|
||||
@ragflowController.get('/get_user_permission', dependencies=[Depends(CheckRoleInterfaceAuth('pad'))])
|
||||
async def get_user_permission(current_user = Depends(LoginService.get_current_user)):
|
||||
"""
|
||||
获取用户权限
|
||||
"""
|
||||
|
||||
user_auth_list = current_user.permissions
|
||||
print(user_auth_list)
|
||||
|
||||
|
||||
return ResponseUtil.success(data=user_auth_list)
|
||||
|
||||
def parse_result(result):
|
||||
code = result.get('code', 0)
|
||||
if code != 0:
|
||||
msg = result.get('message') or result.get('msg') or '接口异常'
|
||||
return ResponseUtil.error(msg=msg, data=result.get('data', None))
|
||||
return ResponseUtil.success(data=result.get('data', None))
|
||||
|
||||
|
||||
def build_chat_cache_key(chat_id: str, question: str) -> str:
|
||||
digest = hashlib.sha256(question.encode('utf-8')).hexdigest()
|
||||
return f'ragflow:chat:{chat_id}:{digest}'
|
||||
|
||||
|
||||
def format_sse(data: dict, event: str | None = None) -> str:
|
||||
payload = json.dumps(data, ensure_ascii=False)
|
||||
prefix = f'event: {event}\n' if event else ''
|
||||
return f'{prefix}data: {payload}\n\n'
|
||||
|
||||
|
||||
async def aiter_sync_generator(sync_gen):
|
||||
"""将同步 Generator 转换为异步迭代器"""
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
while True:
|
||||
# 在线程池中获取下一个值
|
||||
value = await loop.run_in_executor(None, next, sync_gen, None)
|
||||
if value is None:
|
||||
break
|
||||
yield value
|
||||
except StopIteration:
|
||||
pass
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from typing import AsyncIterator, Union
|
||||
|
||||
import aiofiles
|
||||
import redis
|
||||
from fastapi import Depends, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi_controller.decorators import Controller
|
||||
from fastapi_controller.dependencies import LoginService
|
||||
from fastapi_controller.response import ResponseUtil
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from module_admin.service.ragflow_service import RAGFlowService
|
||||
from module_admin.service.search_service import SearchService
|
||||
from module_admin.service.login_service import CheckRoleInterfaceAuth
|
||||
from module_admin.model.ragflow_models import (
|
||||
ConverseWithChatAssistantModel,
|
||||
CreateSessionWithChatModel,
|
||||
DeleteFileModel,
|
||||
DocumentIdsModel,
|
||||
RagflowListQueryModel,
|
||||
UpdateChatAssistantModel,
|
||||
)
|
||||
|
||||
import logger as logger_module
|
||||
@ -1,196 +1,154 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from module_admin.entity.vo.ragflow_vo import RagflowListQueryModel, ListDocumentsQueryModel, UpdateFileModel, DeleteFileModel, CreateDatasetModel, DocumentIdsModel, UpdateChatAssistantModel \
|
||||
,CreateSessionWithChatModel, ConverseWithChatAssistantModel
|
||||
from module_admin.entity.vo.ragflow_vo import (
|
||||
RagflowListQueryModel,
|
||||
ListDocumentsQueryModel,
|
||||
UpdateFileModel,
|
||||
DeleteFileModel,
|
||||
CreateDatasetModel,
|
||||
DocumentIdsModel,
|
||||
UpdateChatAssistantModel,
|
||||
CreateSessionWithChatModel,
|
||||
ConverseWithChatAssistantModel
|
||||
)
|
||||
from utils.ragflow_util import RAGFlowClient
|
||||
from config.env import RAGFlowConfig
|
||||
from utils.ragflow_client_manager import get_ragflow_client
|
||||
|
||||
|
||||
class RAGFlowService:
|
||||
"""
|
||||
RAGFlow服务
|
||||
RAGFlow服务 - 简化版本,使用同步操作
|
||||
"""
|
||||
|
||||
# 获取数据集列表
|
||||
@classmethod
|
||||
async def get_ragflow_dataset_list_services(cls, query_db: AsyncSession, rage_flow_query: RagflowListQueryModel):
|
||||
def get_ragflow_dataset_list_services(cls, query_db, rage_flow_query: RagflowListQueryModel):
|
||||
"""
|
||||
获取数据集列表
|
||||
获取数据集列表 - 同步版本
|
||||
"""
|
||||
|
||||
client = await get_ragflow_client()
|
||||
result = await client.list_datasets(**(rage_flow_query.model_dump()))
|
||||
|
||||
# 获取分页数据
|
||||
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
|
||||
result = client.list_datasets(**(rage_flow_query.model_dump()))
|
||||
return result
|
||||
|
||||
# 创建数据集
|
||||
@classmethod
|
||||
async def create_dataset_services(cls, create_dataset_params: CreateDatasetModel):
|
||||
"""创建数据集
|
||||
|
||||
Args:
|
||||
create_dataset_params (CreateDatasetModel): 创建参数
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
client = await get_ragflow_client()
|
||||
result = await client.create_dataset(
|
||||
**(create_dataset_params.model_dump())
|
||||
)
|
||||
|
||||
def create_dataset_services(cls, create_dataset_params: CreateDatasetModel):
|
||||
"""创建数据集 - 同步版本"""
|
||||
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
|
||||
result = client.create_dataset(**(create_dataset_params.model_dump()))
|
||||
return result
|
||||
|
||||
# 更新数据集
|
||||
@classmethod
|
||||
async def update_dataset_services(
|
||||
cls,
|
||||
dataset_id: str,
|
||||
update_dataset_params: CreateDatasetModel,
|
||||
):
|
||||
"""更新数据集信息
|
||||
|
||||
Args:
|
||||
dataset_id (str): 数据集id
|
||||
update_dataset_params (CreateDatasetModel): 更新参数
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
client = await get_ragflow_client()
|
||||
result = await client.update_dataset(
|
||||
dataset_id=dataset_id, **(update_dataset_params.model_dump())
|
||||
)
|
||||
|
||||
def update_dataset_services(cls, dataset_id: str, update_dataset_params: CreateDatasetModel):
|
||||
"""更新数据集信息 - 同步版本"""
|
||||
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
|
||||
result = client.update_dataset(dataset_id=dataset_id, **(update_dataset_params.model_dump()))
|
||||
return result
|
||||
|
||||
|
||||
# 获取数据集中文档列表
|
||||
@classmethod
|
||||
async def list_documents_services(
|
||||
cls,
|
||||
query_db: AsyncSession,
|
||||
dataset_id: str,
|
||||
list_documents_query: ListDocumentsQueryModel,
|
||||
):
|
||||
client = await get_ragflow_client()
|
||||
result = await client.list_documents(dataset_id=dataset_id, **(list_documents_query.model_dump()))
|
||||
|
||||
def list_documents_services(cls, query_db, dataset_id: str, list_documents_query: ListDocumentsQueryModel):
|
||||
"""获取文档列表 - 同步版本"""
|
||||
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
|
||||
result = client.list_documents(dataset_id=dataset_id, **(list_documents_query.model_dump()))
|
||||
return result
|
||||
|
||||
# 上传文档到数据集
|
||||
@classmethod
|
||||
async def upload_file_dataset_services(
|
||||
cls,
|
||||
dataset_id: str,
|
||||
files,
|
||||
):
|
||||
client = await get_ragflow_client()
|
||||
result = await client.upload_documents_bytes(dataset_id=dataset_id, file_bytes=files)
|
||||
|
||||
def upload_file_dataset_services(cls, dataset_id: str, files):
|
||||
"""上传文档 - 同步版本"""
|
||||
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
|
||||
result = client.upload_documents_bytes(dataset_id=dataset_id, file_bytes=files)
|
||||
return result
|
||||
|
||||
# 开始解析文档
|
||||
@classmethod
|
||||
async def parse_documents_services(
|
||||
cls,
|
||||
dataset_id: str,
|
||||
parse_params: DocumentIdsModel,
|
||||
):
|
||||
client = await get_ragflow_client()
|
||||
result = await client.parse_documents(dataset_id=dataset_id, document_ids=parse_params.documnet_ids)
|
||||
|
||||
def parse_documents_services(cls, dataset_id: str, parse_params: DocumentIdsModel):
|
||||
"""解析文档 - 同步版本"""
|
||||
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
|
||||
result = client.parse_documents(dataset_id=dataset_id, document_ids=parse_params.documnet_ids)
|
||||
return result
|
||||
|
||||
# 停止解析文档
|
||||
@classmethod
|
||||
async def stop_parse_documents_services(
|
||||
cls,
|
||||
dataset_id: str,
|
||||
parse_params: DocumentIdsModel,
|
||||
):
|
||||
client = await get_ragflow_client()
|
||||
result = await client.stop_parsing_documents(dataset_id=dataset_id, document_ids=parse_params.documnet_ids)
|
||||
|
||||
def stop_parse_documents_services(cls, dataset_id: str, parse_params: DocumentIdsModel):
|
||||
"""停止解析文档 - 同步版本"""
|
||||
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
|
||||
result = client.stop_parsing_documents(dataset_id=dataset_id, document_ids=parse_params.documnet_ids)
|
||||
return result
|
||||
|
||||
|
||||
# 更新文档内容
|
||||
@classmethod
|
||||
async def update_file_dataset_services(
|
||||
cls,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
update_params: UpdateFileModel,
|
||||
):
|
||||
client = await get_ragflow_client()
|
||||
result = await client.update_document(dataset_id=dataset_id, document_id=document_id, **(update_params.model_dump()))
|
||||
|
||||
def update_file_dataset_services(cls, dataset_id: str, document_id: str, update_params: UpdateFileModel):
|
||||
"""更新文档 - 同步版本"""
|
||||
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
|
||||
result = client.update_document(dataset_id=dataset_id, document_id=document_id, **(update_params.model_dump()))
|
||||
return result
|
||||
|
||||
# 删除文档
|
||||
@classmethod
|
||||
async def delete_file_services(
|
||||
cls,
|
||||
dataset_id: str,
|
||||
delete_params: DeleteFileModel,
|
||||
):
|
||||
client = await get_ragflow_client()
|
||||
result = await client.delete_documents(dataset_id=dataset_id, **(delete_params.model_dump()))
|
||||
def delete_file_services(cls, dataset_id: str, delete_params: DeleteFileModel):
|
||||
"""删除文档 - 同步版本"""
|
||||
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
|
||||
result = client.delete_documents(dataset_id=dataset_id, **(delete_params.model_dump()))
|
||||
return result
|
||||
|
||||
# 删除数据集
|
||||
@classmethod
|
||||
async def delete_datasets_services(
|
||||
cls,
|
||||
delete_params: DeleteFileModel,
|
||||
):
|
||||
client = await get_ragflow_client()
|
||||
result = await client.delete_datasets(**(delete_params.model_dump()))
|
||||
def delete_datasets_services(cls, delete_params: DeleteFileModel):
|
||||
"""删除数据集 - 同步版本"""
|
||||
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
|
||||
result = client.delete_datasets(**(delete_params.model_dump()))
|
||||
return result
|
||||
|
||||
|
||||
# 查看聊天助手列表
|
||||
@classmethod
|
||||
async def get_chat_assistant_list_services(
|
||||
cls,
|
||||
query_params: RagflowListQueryModel,
|
||||
):
|
||||
client = await get_ragflow_client()
|
||||
result = await client.list_chat_assistants(**(query_params.model_dump()))
|
||||
def get_chat_assistant_list_services(cls, query_params: RagflowListQueryModel):
|
||||
"""获取聊天助手列表 - 同步版本"""
|
||||
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
|
||||
result = client.list_chat_assistants(**(query_params.model_dump()))
|
||||
return result
|
||||
|
||||
# 修改聊天助手
|
||||
@classmethod
|
||||
async def update_chat_assistant_services(cls, update_params: UpdateChatAssistantModel):
|
||||
client = await get_ragflow_client()
|
||||
result = await client.update_chat_assistant(**(update_params.model_dump()))
|
||||
def update_chat_assistant_services(cls, update_params: UpdateChatAssistantModel):
|
||||
"""更新聊天助手 - 同步版本"""
|
||||
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
|
||||
result = client.update_chat_assistant(**(update_params.model_dump()))
|
||||
return result
|
||||
|
||||
# 创建助手会话
|
||||
@classmethod
|
||||
async def create_session_with_chat_services(cls, create_params: CreateSessionWithChatModel):
|
||||
client = await get_ragflow_client()
|
||||
result = await client.create_session_with_chat(**(create_params.model_dump()))
|
||||
def create_session_with_chat_services(cls, create_params: CreateSessionWithChatModel):
|
||||
"""创建会话 - 同步版本"""
|
||||
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
|
||||
result = client.create_session_with_chat(**(create_params.model_dump()))
|
||||
return result
|
||||
|
||||
|
||||
# 与助手聊天
|
||||
# 与助手聊天 - 核心方法
|
||||
@classmethod
|
||||
async def converse_with_chat_assistant_services(cls, converse_params: ConverseWithChatAssistantModel):
|
||||
client = await get_ragflow_client()
|
||||
# 修复:直接返回AsyncGenerator,不使用await消费流式数据
|
||||
return client.converse_with_chat_assistant(**(converse_params.model_dump()))
|
||||
def converse_with_chat_assistant_services(cls, converse_params: ConverseWithChatAssistantModel):
|
||||
"""与聊天助手对话 - 同步版本,返回Generator"""
|
||||
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
|
||||
# 直接返回Generator,支持流式响应
|
||||
return client.converse_with_chat_assistant(
|
||||
chat_id=converse_params.chat_id,
|
||||
question=converse_params.question,
|
||||
stream=converse_params.stream,
|
||||
session_id=converse_params.session_id
|
||||
)
|
||||
|
||||
# 与助手聊天 (OpenAI Compatible)
|
||||
# 与助手聊天 (OpenAI Compatible) - 同步版本
|
||||
@classmethod
|
||||
async def converse_with_chat_assistant_services_openai(cls, converse_params: ConverseWithChatAssistantModel):
|
||||
client = await get_ragflow_client()
|
||||
def converse_with_chat_assistant_services_openai(cls, converse_params: ConverseWithChatAssistantModel):
|
||||
"""OpenAI兼容格式 - 同步版本"""
|
||||
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
|
||||
# Construct messages list for OpenAI format
|
||||
messages = [{"role": "user", "content": converse_params.question}]
|
||||
# Uses defaults for model name as per user indication "server will parse this automatically"
|
||||
return await client.create_chat_completion(
|
||||
return client.create_chat_completion(
|
||||
chat_id=converse_params.chat_id,
|
||||
model="ragflow",
|
||||
messages=messages,
|
||||
stream=converse_params.stream
|
||||
)
|
||||
|
||||
|
||||
|
||||
)
|
||||
@ -189,8 +189,8 @@ class SearchService:
|
||||
except SearchServiceError as exc:
|
||||
logger.warning('搜索服务失败,降级到RAGFlow: %s', exc)
|
||||
# Fallback to normal flow
|
||||
# 修复:await async方法获取AsyncGenerator
|
||||
result = await RAGFlowService.converse_with_chat_assistant_services(converse_params)
|
||||
# 修复:直接调用同步方法,不再使用await
|
||||
result = RAGFlowService.converse_with_chat_assistant_services(converse_params)
|
||||
|
||||
if converse_params.stream:
|
||||
async def stream_response():
|
||||
|
||||
193
ruoyi-fastapi-backend/test_simplified_implementation.py
Normal file
193
ruoyi-fastapi-backend/test_simplified_implementation.py
Normal file
@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
简化后的RAGFlow实现测试脚本
|
||||
验证同步Generator在异步环境中的工作
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Generator, Any
|
||||
|
||||
# 模拟同步RAGFlow客户端
|
||||
class MockSyncRAGFlowClient:
|
||||
"""模拟同步RAGFlow客户端,返回Generator"""
|
||||
|
||||
def __init__(self, base_url: str, api_key: str):
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
|
||||
def converse_with_chat_assistant(self, **kwargs) -> Generator[dict, None, None]:
|
||||
"""模拟流式对话,返回同步Generator"""
|
||||
print(f"Mock client: 开始生成流式数据...")
|
||||
|
||||
# 模拟流式数据生成
|
||||
responses = [
|
||||
{'data': {'answer': 'Hello'}},
|
||||
{'data': {'answer': 'Hello, I am'}},
|
||||
{'data': {'answer': 'Hello, I am an AI'}},
|
||||
{'data': {'answer': 'Hello, I am an AI assistant'}},
|
||||
{'data': {'answer': 'Hello, I am an AI assistant.'}},
|
||||
]
|
||||
|
||||
for i, response in enumerate(responses):
|
||||
print(f"Mock client: 生成第{i+1}个数据块")
|
||||
time.sleep(0.5) # 模拟网络延迟
|
||||
yield response
|
||||
|
||||
print(f"Mock client: 流式数据生成完成")
|
||||
|
||||
|
||||
# 模拟同步RAGFlowService
|
||||
class MockRAGFlowService:
|
||||
"""模拟同步RAGFlowService"""
|
||||
|
||||
@staticmethod
|
||||
def converse_with_chat_assistant_services(converse_params) -> Generator[dict, None, None]:
|
||||
"""返回同步Generator"""
|
||||
print("MockService: 调用converse_with_chat_assistant_services")
|
||||
client = MockSyncRAGFlowClient("http://localhost:9099", "test_key")
|
||||
return client.converse_with_chat_assistant(
|
||||
chat_id=converse_params.get('chat_id'),
|
||||
question=converse_params.get('question'),
|
||||
stream=True,
|
||||
session_id=converse_params.get('session_id')
|
||||
)
|
||||
|
||||
|
||||
# 模拟异步控制器
|
||||
async def async_controller_test():
|
||||
"""测试异步控制器中消费同步Generator"""
|
||||
|
||||
# 模拟参数
|
||||
params = {
|
||||
'chat_id': 'test_chat_123',
|
||||
'question': '你好,请介绍一下自己',
|
||||
'stream': True,
|
||||
'session_id': 'session_456'
|
||||
}
|
||||
|
||||
print("=" * 60)
|
||||
print("测试:异步控制器消费同步Generator")
|
||||
print("=" * 60)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
first_token_received = False
|
||||
|
||||
try:
|
||||
# 调用服务层(同步方法)
|
||||
print("1. 调用RAGFlowService.converse_with_chat_assistant_services...")
|
||||
result = MockRAGFlowService.converse_with_chat_assistant_services(params)
|
||||
print(f" 返回类型: {type(result)}")
|
||||
|
||||
if not isinstance(result, Generator):
|
||||
raise TypeError(f"期望Generator类型,但得到 {type(result)}")
|
||||
|
||||
# 在异步上下文中消费同步Generator
|
||||
print("2. 开始消费同步Generator...")
|
||||
chunk_count = 0
|
||||
|
||||
try:
|
||||
for chunk in result:
|
||||
chunk_count += 1
|
||||
print(f" 接收到第{chunk_count}个数据块: {chunk}")
|
||||
|
||||
# 检查第一个token延迟
|
||||
if not first_token_received:
|
||||
first_token_received = True
|
||||
latency = time.perf_counter() - start_time
|
||||
print(f" 首Token延迟: {latency:.3f}s")
|
||||
|
||||
# 模拟处理每个chunk
|
||||
await asyncio.sleep(0.01) # 让出控制权
|
||||
|
||||
except Exception as e:
|
||||
print(f" 消费数据时出错: {e}")
|
||||
raise
|
||||
|
||||
total_time = time.perf_counter() - start_time
|
||||
print(f"3. 流式处理完成,总耗时: {total_time:.3f}s")
|
||||
print(f" 总共接收数据块: {chunk_count}")
|
||||
|
||||
# 验证结果
|
||||
if chunk_count == 5:
|
||||
print("✅ 测试通过:成功接收到所有5个数据块")
|
||||
else:
|
||||
print(f"❌ 测试失败:期望5个数据块,实际收到{chunk_count}个")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 测试失败:{e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
# 测试同步消费 vs 异步消费
|
||||
def sync_vs_async_test():
|
||||
"""测试同步消费和异步消费的差异"""
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("测试:同步消费 vs 异步消费")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建同步Generator
|
||||
def sync_generator():
|
||||
for i in range(5):
|
||||
time.sleep(0.1)
|
||||
yield f"数据块 {i+1}"
|
||||
|
||||
generator = sync_generator()
|
||||
|
||||
# 1. 同步消费
|
||||
print("1. 同步消费测试:")
|
||||
start_time = time.perf_counter()
|
||||
for item in generator:
|
||||
print(f" {item}")
|
||||
sync_time = time.perf_counter() - start_time
|
||||
print(f" 同步消费耗时: {sync_time:.3f}s")
|
||||
|
||||
# 2. 异步消费
|
||||
print("\n2. 异步消费测试:")
|
||||
generator2 = sync_generator()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
async def async_consumer(gen):
|
||||
count = 0
|
||||
for item in gen:
|
||||
count += 1
|
||||
print(f" {item}")
|
||||
await asyncio.sleep(0.01) # 让出控制权
|
||||
return count
|
||||
|
||||
async def run_async_test():
|
||||
return await async_consumer(generator2)
|
||||
|
||||
try:
|
||||
count = asyncio.run(run_async_test())
|
||||
async_time = time.perf_counter() - start_time
|
||||
print(f" 异步消费耗时: {async_time:.3f}s")
|
||||
print(f" 处理了{count}个项目")
|
||||
except Exception as e:
|
||||
print(f" 异步消费失败: {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""主测试函数"""
|
||||
print("RAGFlow简化实现测试")
|
||||
print("=" * 60)
|
||||
|
||||
# 运行主要测试
|
||||
await async_controller_test()
|
||||
|
||||
# 运行对比测试
|
||||
sync_vs_async_test()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("测试总结:")
|
||||
print("1. 同步Generator可以在异步环境中正常工作")
|
||||
print("2. 使用for循环可以自动处理同步Generator")
|
||||
print("3. 异步消费需要适当让出控制权(await)")
|
||||
print("4. 简化后的架构避免了复杂的async/await链")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Loading…
Reference in New Issue
Block a user