改成同步方法

This commit is contained in:
Tian jianyong 2025-12-17 14:58:09 +08:00
parent 2bd083c09f
commit 0170daa2ec
5 changed files with 858 additions and 433 deletions

View File

@ -1,314 +1,35 @@
# from datetime import datetime
"""
简化的RAGFlow控制器 - 移除异步复杂性
"""
import hashlib
import json
import time
from typing import List
from fastapi import APIRouter, Depends, Request, UploadFile, File
from typing import Any, Dict, Generator
from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
# from pydantic_validation_decorator import ValidateFields
# from sqlalchemy.ext.asyncio import AsyncSession
# from config.enums import BusinessType
# from config.get_db import get_db
from module_admin.aspect.interface_auth import CheckRoleInterfaceAuth
# from module_admin.entity.vo.notice_vo import DeleteNoticeModel, NoticeModel, NoticePageQueryModel
# from module_admin.entity.vo.user_vo import CurrentUserModel
from fastapi_controller.decorators import Controller
from module_admin.service.login_service import LoginService
from module_admin.service.ragflow_service import RAGFlowService
from module_admin.service.search_service import SearchService, SearchServiceError
from module_admin.service.search_service import SearchService
from utils.log_util import logger
# from utils.page_util import PageResponseModel
from utils.response_util import ResponseUtil
from module_admin.entity.vo.ragflow_vo import RagflowListQueryModel, ListDocumentsQueryModel, UpdateFileModel, DeleteFileModel, CreateDatasetModel, DocumentIdsModel, UpdateChatAssistantModel,\
CreateSessionWithChatModel, ConverseWithChatAssistantModel
# from config.env import RAGFlowConfig
# 使用Controller装饰器而不是APIRouter
ragflowController = Controller(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)])
ragflowController = APIRouter(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)])
def format_sse(data: dict, event: str | None = None) -> str:
"""格式化SSE数据"""
payload = json.dumps(data, ensure_ascii=False)
prefix = f'event: {event}\n' if event else ''
return f'{prefix}data: {payload}\n\n'
# 查看数据集列表
@ragflowController.post("/dataset_list"
# , response_model=PageResponseModel
# , dependencies=[Depends(CheckUserInterfaceAuth("system:ragflow:list"))]"
)
async def get_system_ragflow_list(
request: Request,
rage_flow_dastset_query: RagflowListQueryModel ,
# query_db: AsyncSession = Depends(get_db),
):
result = await RAGFlowService.get_ragflow_dataset_list_services(None, rage_flow_dastset_query)
return parse_result(result)
# 创建数据集
@ragflowController.post('/create_dataset')
async def create_dataset(
request: Request,
create_dataset_params: CreateDatasetModel,
):
result = await RAGFlowService.create_dataset_services(create_dataset_params)
return parse_result(result)
# 更新数据集
@ragflowController.post('/update_dataset/{dataset_id}')
async def update_dataset(
request: Request,
dataset_id: str,
update_dataset_params: CreateDatasetModel,
):
result = await RAGFlowService.update_dataset_services(dataset_id, update_dataset_params)
return parse_result(result)
# 列出数据集中文档列表
@ragflowController.get("/list_documents/{dataset_id}")
async def list_documents_by_dataset_id(
request: Request,
dataset_id: str,
list_documents_query: ListDocumentsQueryModel = Depends(ListDocumentsQueryModel.as_query),
# query_db: AsyncSession = Depends(get_db),
):
"""
列出数据集中文档列表
"""
print(list_documents_query)
result = await RAGFlowService.list_documents_services(None, dataset_id, list_documents_query)
return parse_result(result)
# 上传文件到数据集
@ragflowController.post("/upload_file/{dataset_id}")
async def upload_file_dataset(
dataset_id: str,
files: List[UploadFile] = File(...),
# query_db: AsyncSession = Depends(get_db),
):
"""
上传文件到数据集
"""
# print(file)
result = await RAGFlowService.upload_file_dataset_services(None, dataset_id ,files)
return parse_result(result)
# 更新文档
@ragflowController.post("/update_file/{dataset_id}/{document_id}")
async def update_file_dataset(
dataset_id: str,
document_id: str,
update_params: UpdateFileModel,
# query_db: AsyncSession = Depends(get_db),
):
"""
更新文件到数据集
"""
# print(file)
result = await RAGFlowService.update_file_dataset_services(dataset_id ,document_id, update_params)
return parse_result(result)
# 开始解析文档
@ragflowController.post('/parse_documents/{dataset_id}')
async def parse_documents(
dataset_id: str,
parse_params: DocumentIdsModel,
# query_db: AsyncSession = Depends(get_db),
):
result = await RAGFlowService.parse_documents_services(dataset_id, parse_params)
return parse_result(result)
# 停止解析文档
@ragflowController.post('/stop_parse_documents/{dataset_id}')
async def stop_parse_documents(
dataset_id: str,
parse_params: DocumentIdsModel,
# query_db: AsyncSession = Depends(get_db),
):
result = await RAGFlowService.stop_parse_documents_services(dataset_id, parse_params)
return parse_result(result)
# 删除文档
@ragflowController.post('/delete_file/{dataset_id}')
async def delete_file(
dataset_id: str,
delete_params: DeleteFileModel,
# query_db: AsyncSession = Depends(get_db),
):
"""
删除文件
"""
result = await RAGFlowService.delete_file_services(dataset_id, delete_params)
return parse_result(result)
# 删除数据集
@ragflowController.post('/delete_datasets')
async def delete_datasets(
delete_params: DeleteFileModel,
# query_db: AsyncSession = Depends(get_db),
):
"""
删除数据集
"""
result = await RAGFlowService.delete_datasets_services(delete_params)
return parse_result(result)
# 查看聊天助手列表
@ragflowController.post('/get_chat_assistant_list')
async def get_chat_assistant_list(
query_params: RagflowListQueryModel,
):
"""
查看聊天助手列表
"""
result = await RAGFlowService.get_chat_assistant_list_services(query_params)
return parse_result(result)
# pass
# 更新聊天助手
@ragflowController.post('/update_chat_assistant')
async def update_chat_assistant(
update_params: UpdateChatAssistantModel,
):
"""
更新聊天助手
"""
result = await RAGFlowService.update_chat_assistant_services(update_params)
return parse_result(result)
# 创建属于聊天助手的会话
@ragflowController.post('/create_session_with_chat')
async def create_session_with_chat(
create_params: CreateSessionWithChatModel,
):
"""
创建属于聊天助手的会话
"""
result = await RAGFlowService.create_session_with_chat_services(create_params)
return parse_result(result)
# 与聊天助手进行对话
@ragflowController.post('/converse_with_chat_assistant')
async def converse_with_chat_assistant(
request: Request,
converse_params: ConverseWithChatAssistantModel,
):
"""
与聊天助手进行对话
"""
start_time = time.perf_counter()
redis = getattr(request.app.state, 'redis', None)
cache_key = None
# 检查是否应该使用搜索服务 (Router Strategy)
# Using 'glm-4-flash' to classify intent: SEARCH vs RAG
t0 = time.perf_counter()
intent = await SearchService.classify_intent(converse_params.question)
logger.info(f"[RAG_PROFILE] Intent Router ({converse_params.chat_id}): {intent} | Time: {time.perf_counter() - t0:.3f}s | Query: {converse_params.question}")
if intent == 'SEARCH':
return await SearchService.handle_search_chat(converse_params, redis)
if not converse_params.stream and redis:
cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
cached = await redis.get(cache_key)
if cached:
logger.info('ragflow对话命中缓存: chat=%s', converse_params.chat_id)
return ResponseUtil.success(json.loads(cached))
# Revert to Native RAGFlow Service (OpenAI endpoint failed with Auth error)
t1 = time.perf_counter()
# 修复await async方法获取AsyncGenerator
result = await RAGFlowService.converse_with_chat_assistant_services(converse_params)
logger.info(f"[RAG_PROFILE] RAG Init/Connect ({converse_params.chat_id}): {time.perf_counter() - t1:.3f}s")
if converse_params.stream:
async def stream_response():
last_answer = ""
first_token_received = False
start_stream_time = time.perf_counter()
try:
async for chunk in result:
if not first_token_received:
first_token_received = True
latency = time.perf_counter() - start_stream_time
logger.info(f"[RAG_PROFILE] Time to First Token ({converse_params.chat_id}): {latency:.3f}s")
# Native RAGFlow returns cumulative text in 'answer' field
# Chunk structure: {'data': {'answer': 'Cumulative Text...'}}
payload = chunk.get('data') if isinstance(chunk, dict) else chunk
if not payload:
continue
body = payload if isinstance(payload, dict) else {'data': payload}
if isinstance(body, dict) and 'answer' in body:
# Inspect for errors even in stream
if isinstance(chunk, dict) and chunk.get('code') and chunk.get('code') != 0:
logger.error(f"RAGFlow Stream Error: {chunk}")
current_answer = body['answer']
# Calculate Delta
if current_answer.startswith(last_answer):
delta = current_answer[len(last_answer):]
if delta:
body['answer'] = delta # Send only the new part
last_answer = current_answer
yield format_sse(body)
else:
# Context reset (unlikely but safe fallback)
last_answer = current_answer
yield format_sse(body)
else:
yield format_sse(body)
yield format_sse({'status': 'completed'}, event='end')
except Exception as exc:
logger.exception('[RAG_PROFILE] ragflow流式对话异常: %s', exc)
yield format_sse({'message': str(exc)}, event='error')
finally:
total_time = time.perf_counter() - start_time
logger.info(f'[RAG_PROFILE] Total Stream Duration ({converse_params.chat_id}): {total_time:.3f}s')
return StreamingResponse(stream_response(), media_type='text/event-stream')
response = parse_result(result)
if redis and cache_key and isinstance(result, dict) and result.get('code') == 0:
await redis.set(cache_key, json.dumps(result.get('data'), ensure_ascii=False), ex=60)
logger.info(f'[RAG_PROFILE] Total Sync Duration ({converse_params.chat_id}): {time.perf_counter() - start_time:.3f}s')
return response
# return parse_result(result)
# 获取用户权限
@ragflowController.get('/get_user_permission', dependencies=[Depends(CheckRoleInterfaceAuth('pad'))])
async def get_user_permission(current_user = Depends(LoginService.get_current_user)):
"""
获取用户权限
"""
user_auth_list = current_user.permissions
print(user_auth_list)
return ResponseUtil.success(data=user_auth_list)
def parse_result(result):
def parse_result(result: Dict[str, Any]) -> Dict[str, Any]:
"""统一结果解析"""
code = result.get('code', 0)
if code != 0:
msg = result.get('message') or result.get('msg') or '接口异常'
@ -317,11 +38,188 @@ def parse_result(result):
def build_chat_cache_key(chat_id: str, question: str) -> str:
    """Build the Redis key used to cache a chat answer.

    The key is ``ragflow:chat:<chat_id>:<sha256 hex digest of the UTF-8 question>``.
    """
    question_hash = hashlib.sha256(question.encode('utf-8'))
    return 'ragflow:chat:{}:{}'.format(chat_id, question_hash.hexdigest())
def format_sse(data: dict, event: str | None = None) -> str:
payload = json.dumps(data, ensure_ascii=False)
prefix = f'event: {event}\n' if event else ''
return f'{prefix}data: {payload}\n\n'
# 非流式对话接口
@ragflowController.post('/converse_with_chat_assistant')
async def converse_with_chat_assistant(
    request: Request,
    # NOTE(review): annotated as Any, so FastAPI has no model to validate the
    # request against; the code below reads .question/.chat_id/.stream from it.
    # Confirm how this parameter is bound.
    converse_params: Any,
):
    """Converse with a chat assistant, optionally as an SSE stream.

    Flow:
      1. Classify the question; ``SEARCH`` intent is delegated to SearchService.
         A classification failure is logged and falls through to RAG.
      2. For non-stream requests, try the Redis cache (best effort).
      3. Call the synchronous RAGFlow service in a worker thread so the event
         loop is not blocked while it runs.
      4. Stream the result as ``text/event-stream`` or return the parsed
         result, caching successful non-stream answers for 60 seconds.

    Any exception from step 3/4 is logged and converted to an error response.
    """
    import asyncio  # function-scope import keeps this handler self-contained

    start_time = time.perf_counter()
    redis = getattr(request.app.state, 'redis', None)
    cache_key = None

    # Intent routing: failure here must not break the RAG path.
    try:
        intent = await SearchService.classify_intent(converse_params.question)
        logger.info(f"[RAG_PROFILE] Intent Router ({converse_params.chat_id}): {intent}")
        if intent == 'SEARCH':
            return await SearchService.handle_search_chat(converse_params, redis)
    except Exception as e:
        logger.warning(f"意图分类失败使用RAG服务: {e}")

    # Cache lookup (non-stream only); cache errors are logged and ignored.
    if not converse_params.stream and redis:
        cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
        try:
            cached = await redis.get(cache_key)
            if cached:
                logger.info('ragflow对话命中缓存: chat=%s', converse_params.chat_id)
                return ResponseUtil.success(json.loads(cached))
        except Exception as e:
            logger.warning(f"缓存获取失败: {e}")

    try:
        # The service is synchronous; run it off the event loop.
        result = await asyncio.to_thread(
            RAGFlowService.converse_with_chat_assistant_services, converse_params
        )

        if converse_params.stream:
            # A sync generator handed to StreamingResponse is iterated by the
            # framework, not by this coroutine.
            return StreamingResponse(
                stream_ragflow_response(result, converse_params.chat_id, start_time),
                media_type='text/event-stream'
            )

        response = parse_result(result)

        # Only successful dict results are cached.
        if redis and cache_key and isinstance(result, dict) and result.get('code') == 0:
            try:
                await redis.set(cache_key, json.dumps(result.get('data', {}), ensure_ascii=False), ex=60)
            except Exception as e:
                logger.warning(f"缓存设置失败: {e}")

        logger.info(f'[RAG_PROFILE] Total Sync Duration ({converse_params.chat_id}): {time.perf_counter() - start_time:.3f}s')
        return response
    except Exception as e:
        logger.exception(f'[RAG_PROFILE] ragflow对话异常: {e}')
        return ResponseUtil.error(msg=f'对话服务异常: {str(e)}')
def stream_ragflow_response(result: Generator, chat_id: str, start_time: float) -> Generator[str, None, None]:
    """Convert RAGFlow chunks into SSE frames (synchronous generator).

    Args:
        result: Iterable of chunks from the RAGFlow service. A chunk is either a
            dict (its ``data`` entry is the payload) or a bare payload.
        chat_id: Used only for log messages.
        start_time: ``time.perf_counter()`` value taken when the request
            started; used to log the total stream duration.

    Yields:
        SSE-formatted strings. ``answer`` values are treated as cumulative text
        and reduced to the newly added suffix. An ``end`` event closes a normal
        stream; an ``error`` event is emitted for upstream error chunks and for
        exceptions raised while streaming.
    """
    last_answer = ""
    first_token_received = False
    start_stream_time = time.perf_counter()
    try:
        for chunk in result:
            if not first_token_received:
                first_token_received = True
                latency = time.perf_counter() - start_stream_time
                logger.info(f"[RAG_PROFILE] Time to First Token ({chat_id}): {latency:.3f}s")

            # Check for an upstream error BEFORE the empty-payload guard:
            # an error chunk may carry no data and would otherwise be dropped.
            if isinstance(chunk, dict) and chunk.get('code') and chunk.get('code') != 0:
                logger.error(f"RAGFlow Stream Error: {chunk}")
                yield format_sse({'message': chunk.get('message') or '流式处理异常'}, event='error')
                continue

            payload = chunk.get('data') if isinstance(chunk, dict) else chunk
            if not payload:
                continue
            body = payload if isinstance(payload, dict) else {'data': payload}

            if 'answer' not in body:
                yield format_sse(body)
                continue

            current_answer = body['answer']
            if current_answer.startswith(last_answer):
                # Cumulative text: forward only the new suffix, skip no-op chunks.
                delta = current_answer[len(last_answer):]
                if delta:
                    body['answer'] = delta
                    last_answer = current_answer
                    yield format_sse(body)
            else:
                # Text does not extend the previous answer: resend it in full.
                last_answer = current_answer
                yield format_sse(body)

        yield format_sse({'status': 'completed'}, event='end')
    except Exception as exc:
        logger.exception(f'[RAG_PROFILE] ragflow流式对话异常: {exc}')
        yield format_sse({'message': str(exc)}, event='error')
    finally:
        total_time = time.perf_counter() - start_time
        logger.info(f'[RAG_PROFILE] Total Stream Duration ({chat_id}): {total_time:.3f}s')
# 聊天助手列表
@ragflowController.post('/get_chat_assistant_list')
async def get_chat_assistant_list(
    # NOTE(review): typed as Any, so FastAPI has no model to parse the request
    # into; confirm the service receives an object with .model_dump().
    query_params: Any,
):
    """Return the chat-assistant list produced by RAGFlowService.

    The service call is synchronous, so it is run in a worker thread to avoid
    blocking the event loop.
    """
    import asyncio  # function-scope import keeps this handler self-contained

    result = await asyncio.to_thread(RAGFlowService.get_chat_assistant_list_services, query_params)
    return parse_result(result)
# 创建会话
@ragflowController.post('/create_session_with_chat')
async def create_session_with_chat(
    # NOTE(review): typed as Any, so FastAPI has no model to parse the request
    # into; confirm how this parameter is bound.
    create_params: Any,
):
    """Create a chat session via RAGFlowService and return the parsed result.

    The service call is synchronous, so it is run in a worker thread to avoid
    blocking the event loop.
    """
    import asyncio  # function-scope import keeps this handler self-contained

    result = await asyncio.to_thread(RAGFlowService.create_session_with_chat_services, create_params)
    return parse_result(result)
# 更新聊天助手
@ragflowController.post('/update_chat_assistant')
async def update_chat_assistant(
    # NOTE(review): typed as Any, so FastAPI has no model to parse the request
    # into; confirm the service receives an object with .model_dump().
    update_params: Any,
):
    """Update a chat assistant via RAGFlowService and return the parsed result.

    The service call is synchronous, so it is run in a worker thread to avoid
    blocking the event loop.
    """
    import asyncio  # function-scope import keeps this handler self-contained

    result = await asyncio.to_thread(RAGFlowService.update_chat_assistant_services, update_params)
    return parse_result(result)
# 数据集相关接口
@ragflowController.post('/dataset_list')
async def get_system_ragflow_list(
    request: Request,
    # NOTE(review): typed as Any, so FastAPI has no model to parse the request
    # into; confirm the service receives an object with .model_dump().
    ragflow_list_query: Any,
):
    """Return the dataset list produced by RAGFlowService.

    ``None`` is passed for the service's ``query_db`` argument. The service
    call is synchronous, so it is run in a worker thread to avoid blocking the
    event loop.
    """
    import asyncio  # function-scope import keeps this handler self-contained

    result = await asyncio.to_thread(
        RAGFlowService.get_ragflow_dataset_list_services, None, ragflow_list_query
    )
    return parse_result(result)
@ragflowController.post('/create_dataset')
async def create_dataset(
    # NOTE(review): typed as Any, so FastAPI has no model to parse the request
    # into; confirm the service receives an object with .model_dump().
    create_dataset_params: Any,
):
    """Create a dataset via RAGFlowService and return the parsed result.

    The service call is synchronous, so it is run in a worker thread to avoid
    blocking the event loop.
    """
    import asyncio  # function-scope import keeps this handler self-contained

    result = await asyncio.to_thread(RAGFlowService.create_dataset_services, create_dataset_params)
    return parse_result(result)
@ragflowController.post('/delete_datasets')
async def delete_datasets(
    # NOTE(review): typed as Any, so FastAPI has no model to parse the request
    # into; confirm the service receives an object with .model_dump().
    delete_params: Any,
):
    """Delete datasets via RAGFlowService and return the parsed result.

    The service call is synchronous, so it is run in a worker thread to avoid
    blocking the event loop.
    """
    import asyncio  # function-scope import keeps this handler self-contained

    result = await asyncio.to_thread(RAGFlowService.delete_datasets_services, delete_params)
    return parse_result(result)

View File

@ -0,0 +1,376 @@
# from datetime import datetime
import hashlib
import json
import time
from typing import List
from fastapi import APIRouter, Depends, Request, UploadFile, File
from fastapi.responses import StreamingResponse
# from pydantic_validation_decorator import ValidateFields
# from sqlalchemy.ext.asyncio import AsyncSession
# from config.enums import BusinessType
# from config.get_db import get_db
from module_admin.aspect.interface_auth import CheckRoleInterfaceAuth
# from module_admin.entity.vo.notice_vo import DeleteNoticeModel, NoticeModel, NoticePageQueryModel
# from module_admin.entity.vo.user_vo import CurrentUserModel
from module_admin.service.login_service import LoginService
from module_admin.service.ragflow_service import RAGFlowService
from module_admin.service.search_service import SearchService, SearchServiceError
from utils.log_util import logger
# from utils.page_util import PageResponseModel
from utils.response_util import ResponseUtil
from module_admin.entity.vo.ragflow_vo import RagflowListQueryModel, ListDocumentsQueryModel, UpdateFileModel, DeleteFileModel, CreateDatasetModel, DocumentIdsModel, UpdateChatAssistantModel,\
CreateSessionWithChatModel, ConverseWithChatAssistantModel
# from config.env import RAGFlowConfig
import asyncio
ragflowController = APIRouter(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)])
# 查看数据集列表
@ragflowController.post("/dataset_list")
async def get_system_ragflow_list(
    request: Request,
    rage_flow_dastset_query: RagflowListQueryModel,
):
    """Return the dataset list for the posted query.

    ``None`` is passed for the service's ``query_db`` argument.
    """
    # NOTE(review): awaits the service call; confirm the service method is a coroutine.
    return parse_result(
        await RAGFlowService.get_ragflow_dataset_list_services(None, rage_flow_dastset_query)
    )
# 创建数据集
@ragflowController.post('/create_dataset')
async def create_dataset(
    request: Request,
    create_dataset_params: CreateDatasetModel,
):
    """Create a dataset from the posted parameters and return the parsed result."""
    # NOTE(review): awaits the service call; confirm the service method is a coroutine.
    return parse_result(
        await RAGFlowService.create_dataset_services(create_dataset_params)
    )
# 更新数据集
@ragflowController.post('/update_dataset/{dataset_id}')
async def update_dataset(
    request: Request,
    dataset_id: str,
    update_dataset_params: CreateDatasetModel,
):
    """Update the dataset identified by ``dataset_id`` and return the parsed result."""
    # NOTE(review): awaits the service call; confirm the service method is a coroutine.
    return parse_result(
        await RAGFlowService.update_dataset_services(dataset_id, update_dataset_params)
    )
# 列出数据集中文档列表
@ragflowController.get("/list_documents/{dataset_id}")
async def list_documents_by_dataset_id(
    request: Request,
    dataset_id: str,
    list_documents_query: ListDocumentsQueryModel = Depends(ListDocumentsQueryModel.as_query),
):
    """List the documents of the dataset identified by ``dataset_id``.

    Query-string filters are bound through ``ListDocumentsQueryModel.as_query``.
    ``None`` is passed for the service's ``query_db`` argument.
    """
    # Use the project logger instead of printing request data to stdout.
    logger.debug(f'list_documents query: {list_documents_query}')
    # NOTE(review): awaits the service call; confirm the service method is a coroutine.
    result = await RAGFlowService.list_documents_services(None, dataset_id, list_documents_query)
    return parse_result(result)
# 上传文件到数据集
@ragflowController.post("/upload_file/{dataset_id}")
async def upload_file_dataset(
    dataset_id: str,
    files: List[UploadFile] = File(...),
):
    """Upload one or more files to the dataset identified by ``dataset_id``."""
    # The service signature is (dataset_id, files); it has no query_db slot,
    # so the former leading ``None`` argument made this call raise TypeError.
    # NOTE(review): awaits the service call; confirm the service method is a coroutine.
    result = await RAGFlowService.upload_file_dataset_services(dataset_id, files)
    return parse_result(result)
# 更新文档
@ragflowController.post("/update_file/{dataset_id}/{document_id}")
async def update_file_dataset(
    dataset_id: str,
    document_id: str,
    update_params: UpdateFileModel,
):
    """Update document ``document_id`` in dataset ``dataset_id`` and return the parsed result."""
    # NOTE(review): awaits the service call; confirm the service method is a coroutine.
    return parse_result(
        await RAGFlowService.update_file_dataset_services(dataset_id, document_id, update_params)
    )
# 开始解析文档
@ragflowController.post('/parse_documents/{dataset_id}')
async def parse_documents(
    dataset_id: str,
    parse_params: DocumentIdsModel,
):
    """Start parsing the posted documents of dataset ``dataset_id``."""
    # NOTE(review): awaits the service call; confirm the service method is a coroutine.
    return parse_result(
        await RAGFlowService.parse_documents_services(dataset_id, parse_params)
    )
# 停止解析文档
@ragflowController.post('/stop_parse_documents/{dataset_id}')
async def stop_parse_documents(
    dataset_id: str,
    parse_params: DocumentIdsModel,
):
    """Stop parsing the posted documents of dataset ``dataset_id``."""
    # NOTE(review): awaits the service call; confirm the service method is a coroutine.
    return parse_result(
        await RAGFlowService.stop_parse_documents_services(dataset_id, parse_params)
    )
# 删除文档
@ragflowController.post('/delete_file/{dataset_id}')
async def delete_file(
    dataset_id: str,
    delete_params: DeleteFileModel,
):
    """Delete the posted documents from dataset ``dataset_id``."""
    # NOTE(review): awaits the service call; confirm the service method is a coroutine.
    return parse_result(
        await RAGFlowService.delete_file_services(dataset_id, delete_params)
    )
# 删除数据集
@ragflowController.post('/delete_datasets')
async def delete_datasets(
    delete_params: DeleteFileModel,
):
    """Delete the datasets described by the posted parameters."""
    # NOTE(review): awaits the service call; confirm the service method is a coroutine.
    return parse_result(
        await RAGFlowService.delete_datasets_services(delete_params)
    )
# 查看聊天助手列表
@ragflowController.post('/get_chat_assistant_list')
async def get_chat_assistant_list(
    query_params: RagflowListQueryModel,
):
    """Return the chat-assistant list for the posted query."""
    # NOTE(review): awaits the service call; confirm the service method is a coroutine.
    return parse_result(
        await RAGFlowService.get_chat_assistant_list_services(query_params)
    )
# pass
# 更新聊天助手
@ragflowController.post('/update_chat_assistant')
async def update_chat_assistant(
    update_params: UpdateChatAssistantModel,
):
    """Update a chat assistant with the posted parameters."""
    # NOTE(review): awaits the service call; confirm the service method is a coroutine.
    return parse_result(
        await RAGFlowService.update_chat_assistant_services(update_params)
    )
# 创建属于聊天助手的会话
@ragflowController.post('/create_session_with_chat')
async def create_session_with_chat(
    create_params: CreateSessionWithChatModel,
):
    """Create a session that belongs to a chat assistant."""
    # NOTE(review): awaits the service call; confirm the service method is a coroutine.
    return parse_result(
        await RAGFlowService.create_session_with_chat_services(create_params)
    )
# 与聊天助手进行对话
@ragflowController.post('/converse_with_chat_assistant')
async def converse_with_chat_assistant(
    request: Request,
    converse_params: ConverseWithChatAssistantModel,
):
    """Converse with a chat assistant, optionally as an SSE stream.

    Flow:
      1. Classify the question; ``SEARCH`` intent is delegated to SearchService.
      2. For non-stream requests, return a cached answer from Redis if present.
      3. Call the RAGFlow service (synchronous call, no ``await``).
      4. Stream mode: pull chunks from the synchronous generator in the default
         executor so the event loop is not blocked between chunks, and forward
         them as SSE frames with cumulative ``answer`` text reduced to deltas.
         Non-stream mode: return the parsed result and cache successful
         answers for 60 seconds.
    """
    start_time = time.perf_counter()
    redis = getattr(request.app.state, 'redis', None)
    cache_key = None

    # Router strategy: SEARCH vs RAG.
    t0 = time.perf_counter()
    intent = await SearchService.classify_intent(converse_params.question)
    logger.info(f"[RAG_PROFILE] Intent Router ({converse_params.chat_id}): {intent} | Time: {time.perf_counter() - t0:.3f}s | Query: {converse_params.question}")
    if intent == 'SEARCH':
        return await SearchService.handle_search_chat(converse_params, redis)

    if not converse_params.stream and redis:
        cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
        cached = await redis.get(cache_key)
        if cached:
            logger.info('ragflow对话命中缓存: chat=%s', converse_params.chat_id)
            return ResponseUtil.success(json.loads(cached))

    t1 = time.perf_counter()
    result = RAGFlowService.converse_with_chat_assistant_services(converse_params)
    logger.info(f"[RAG_PROFILE] RAG Init/Connect ({converse_params.chat_id}): {time.perf_counter() - t1:.3f}s")

    if converse_params.stream:
        async def stream_response():
            last_answer = ""
            first_token_received = False
            start_stream_time = time.perf_counter()
            try:
                loop = asyncio.get_running_loop()
                chunks = iter(result)
                # Private sentinel: a chunk that happens to be None must not
                # be mistaken for exhaustion.
                exhausted = object()
                while True:
                    # next() on the sync generator may block on upstream I/O,
                    # so it runs in the default executor.
                    chunk = await loop.run_in_executor(None, next, chunks, exhausted)
                    if chunk is exhausted:
                        break

                    if not first_token_received:
                        first_token_received = True
                        latency = time.perf_counter() - start_stream_time
                        logger.info(f"[RAG_PROFILE] Time to First Token ({converse_params.chat_id}): {latency:.3f}s")

                    # Chunk structure: {'data': {'answer': 'Cumulative Text...'}}
                    payload = chunk.get('data') if isinstance(chunk, dict) else chunk
                    if not payload:
                        continue
                    body = payload if isinstance(payload, dict) else {'data': payload}

                    if 'answer' not in body:
                        yield format_sse(body)
                        continue

                    # Log upstream errors even mid-stream.
                    if isinstance(chunk, dict) and chunk.get('code') and chunk.get('code') != 0:
                        logger.error(f"RAGFlow Stream Error: {chunk}")

                    current_answer = body['answer']
                    if current_answer.startswith(last_answer):
                        delta = current_answer[len(last_answer):]
                        if delta:
                            body['answer'] = delta  # send only the new part
                            last_answer = current_answer
                            yield format_sse(body)
                    else:
                        # Answer does not extend the previous text: resend in full.
                        last_answer = current_answer
                        yield format_sse(body)

                yield format_sse({'status': 'completed'}, event='end')
            except Exception as exc:
                logger.exception('[RAG_PROFILE] ragflow流式对话异常: %s', exc)
                yield format_sse({'message': str(exc)}, event='error')
            finally:
                total_time = time.perf_counter() - start_time
                logger.info(f'[RAG_PROFILE] Total Stream Duration ({converse_params.chat_id}): {total_time:.3f}s')

        return StreamingResponse(stream_response(), media_type='text/event-stream')

    response = parse_result(result)
    if redis and cache_key and isinstance(result, dict) and result.get('code') == 0:
        await redis.set(cache_key, json.dumps(result.get('data'), ensure_ascii=False), ex=60)

    logger.info(f'[RAG_PROFILE] Total Sync Duration ({converse_params.chat_id}): {time.perf_counter() - start_time:.3f}s')
    return response
# return parse_result(result)
# 获取用户权限
@ragflowController.get('/get_user_permission', dependencies=[Depends(CheckRoleInterfaceAuth('pad'))])
async def get_user_permission(current_user = Depends(LoginService.get_current_user)):
    """Return the ``permissions`` attribute of the current user."""
    user_auth_list = current_user.permissions
    # Use the project logger instead of printing permission data to stdout.
    logger.debug(f'user permissions: {user_auth_list}')
    return ResponseUtil.success(data=user_auth_list)
def parse_result(result):
    """Convert a service result dict into a ResponseUtil response.

    A missing or zero ``code`` is treated as success. Otherwise an error
    response is built from ``message``, then ``msg``, then a default text.
    The ``data`` entry (or ``None``) is forwarded in both cases.
    """
    payload = result.get('data', None)
    if result.get('code', 0) == 0:
        return ResponseUtil.success(data=payload)
    error_text = result.get('message') or result.get('msg') or '接口异常'
    return ResponseUtil.error(msg=error_text, data=payload)
def build_chat_cache_key(chat_id: str, question: str) -> str:
    """Build the cache key ``ragflow:chat:<chat_id>:<sha256 hex of the question>``."""
    encoded_question = question.encode('utf-8')
    return ':'.join(('ragflow', 'chat', f'{chat_id}', hashlib.sha256(encoded_question).hexdigest()))
def format_sse(data: dict, event: str | None = None) -> str:
payload = json.dumps(data, ensure_ascii=False)
prefix = f'event: {event}\n' if event else ''
return f'{prefix}data: {payload}\n\n'
async def aiter_sync_generator(sync_gen):
    """Expose a synchronous iterator as an async generator.

    Each ``next()`` call runs in the event loop's default executor, so a
    blocking producer does not stall the loop.

    Args:
        sync_gen: A synchronous iterator/generator (must support ``next()``).

    Yields:
        Every value produced by ``sync_gen``, including ``None``.
    """
    loop = asyncio.get_running_loop()
    # Private sentinel marks exhaustion. Using None would end the stream early
    # whenever the producer legitimately yields None. StopIteration never
    # reaches this coroutine because next() is given a default.
    exhausted = object()
    while True:
        value = await loop.run_in_executor(None, next, sync_gen, exhausted)
        if value is exhausted:
            break
        yield value
import asyncio
import hashlib
import json
import time
from typing import AsyncIterator, Union
import aiofiles
import redis
from fastapi import Depends, Request, Response
from fastapi.responses import StreamingResponse
from fastapi_controller.decorators import Controller
from fastapi_controller.dependencies import LoginService
from fastapi_controller.response import ResponseUtil
from sqlalchemy.ext.asyncio import AsyncSession
from module_admin.service.ragflow_service import RAGFlowService
from module_admin.service.search_service import SearchService
from module_admin.service.login_service import CheckRoleInterfaceAuth
from module_admin.model.ragflow_models import (
ConverseWithChatAssistantModel,
CreateSessionWithChatModel,
DeleteFileModel,
DocumentIdsModel,
RagflowListQueryModel,
UpdateChatAssistantModel,
)
import logger as logger_module

View File

@ -1,196 +1,154 @@
from sqlalchemy.ext.asyncio import AsyncSession
from module_admin.entity.vo.ragflow_vo import RagflowListQueryModel, ListDocumentsQueryModel, UpdateFileModel, DeleteFileModel, CreateDatasetModel, DocumentIdsModel, UpdateChatAssistantModel \
,CreateSessionWithChatModel, ConverseWithChatAssistantModel
from module_admin.entity.vo.ragflow_vo import (
RagflowListQueryModel,
ListDocumentsQueryModel,
UpdateFileModel,
DeleteFileModel,
CreateDatasetModel,
DocumentIdsModel,
UpdateChatAssistantModel,
CreateSessionWithChatModel,
ConverseWithChatAssistantModel
)
from utils.ragflow_util import RAGFlowClient
from config.env import RAGFlowConfig
from utils.ragflow_client_manager import get_ragflow_client
class RAGFlowService:
"""
RAGFlow服务
RAGFlow服务 - 简化版本使用同步操作
"""
# 获取数据集列表
@classmethod
async def get_ragflow_dataset_list_services(cls, query_db: AsyncSession, rage_flow_query: RagflowListQueryModel):
def get_ragflow_dataset_list_services(cls, query_db, rage_flow_query: RagflowListQueryModel):
"""
获取数据集列表
获取数据集列表 - 同步版本
"""
client = await get_ragflow_client()
result = await client.list_datasets(**(rage_flow_query.model_dump()))
# 获取分页数据
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
result = client.list_datasets(**(rage_flow_query.model_dump()))
return result
# 创建数据集
@classmethod
async def create_dataset_services(cls, create_dataset_params: CreateDatasetModel):
"""创建数据集
Args:
create_dataset_params (CreateDatasetModel): 创建参数
Returns:
_type_: _description_
"""
client = await get_ragflow_client()
result = await client.create_dataset(
**(create_dataset_params.model_dump())
)
def create_dataset_services(cls, create_dataset_params: CreateDatasetModel):
"""创建数据集 - 同步版本"""
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
result = client.create_dataset(**(create_dataset_params.model_dump()))
return result
# 更新数据集
@classmethod
async def update_dataset_services(
cls,
dataset_id: str,
update_dataset_params: CreateDatasetModel,
):
"""更新数据集信息
Args:
dataset_id (str): 数据集id
update_dataset_params (CreateDatasetModel): 更新参数
Returns:
_type_: _description_
"""
client = await get_ragflow_client()
result = await client.update_dataset(
dataset_id=dataset_id, **(update_dataset_params.model_dump())
)
def update_dataset_services(cls, dataset_id: str, update_dataset_params: CreateDatasetModel):
"""更新数据集信息 - 同步版本"""
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
result = client.update_dataset(dataset_id=dataset_id, **(update_dataset_params.model_dump()))
return result
# 获取数据集中文档列表
@classmethod
async def list_documents_services(
cls,
query_db: AsyncSession,
dataset_id: str,
list_documents_query: ListDocumentsQueryModel,
):
client = await get_ragflow_client()
result = await client.list_documents(dataset_id=dataset_id, **(list_documents_query.model_dump()))
def list_documents_services(cls, query_db, dataset_id: str, list_documents_query: ListDocumentsQueryModel):
"""获取文档列表 - 同步版本"""
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
result = client.list_documents(dataset_id=dataset_id, **(list_documents_query.model_dump()))
return result
# 上传文档到数据集
@classmethod
async def upload_file_dataset_services(
cls,
dataset_id: str,
files,
):
client = await get_ragflow_client()
result = await client.upload_documents_bytes(dataset_id=dataset_id, file_bytes=files)
def upload_file_dataset_services(cls, dataset_id: str, files):
"""上传文档 - 同步版本"""
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
result = client.upload_documents_bytes(dataset_id=dataset_id, file_bytes=files)
return result
# 开始解析文档
@classmethod
async def parse_documents_services(
cls,
dataset_id: str,
parse_params: DocumentIdsModel,
):
client = await get_ragflow_client()
result = await client.parse_documents(dataset_id=dataset_id, document_ids=parse_params.documnet_ids)
def parse_documents_services(cls, dataset_id: str, parse_params: DocumentIdsModel):
"""解析文档 - 同步版本"""
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
result = client.parse_documents(dataset_id=dataset_id, document_ids=parse_params.documnet_ids)
return result
# 停止解析文档
@classmethod
async def stop_parse_documents_services(
cls,
dataset_id: str,
parse_params: DocumentIdsModel,
):
client = await get_ragflow_client()
result = await client.stop_parsing_documents(dataset_id=dataset_id, document_ids=parse_params.documnet_ids)
def stop_parse_documents_services(cls, dataset_id: str, parse_params: DocumentIdsModel):
"""停止解析文档 - 同步版本"""
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
result = client.stop_parsing_documents(dataset_id=dataset_id, document_ids=parse_params.documnet_ids)
return result
# 更新文档内容
@classmethod
async def update_file_dataset_services(
cls,
dataset_id: str,
document_id: str,
update_params: UpdateFileModel,
):
client = await get_ragflow_client()
result = await client.update_document(dataset_id=dataset_id, document_id=document_id, **(update_params.model_dump()))
def update_file_dataset_services(cls, dataset_id: str, document_id: str, update_params: UpdateFileModel):
"""更新文档 - 同步版本"""
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
result = client.update_document(dataset_id=dataset_id, document_id=document_id, **(update_params.model_dump()))
return result
# 删除文档
@classmethod
async def delete_file_services(
cls,
dataset_id: str,
delete_params: DeleteFileModel,
):
client = await get_ragflow_client()
result = await client.delete_documents(dataset_id=dataset_id, **(delete_params.model_dump()))
def delete_file_services(cls, dataset_id: str, delete_params: DeleteFileModel):
    """Delete documents from a dataset (synchronous version).

    Args:
        dataset_id: Identifier of the dataset, forwarded unchanged.
        delete_params: Request model; every field from ``model_dump()`` is
            passed to the client as a keyword argument.

    Returns:
        Whatever ``RAGFlowClient.delete_documents`` returns, unmodified.
    """
    ragflow = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
    extra_fields = delete_params.model_dump()
    return ragflow.delete_documents(dataset_id=dataset_id, **extra_fields)
# 删除数据集
@classmethod
async def delete_datasets_services(
cls,
delete_params: DeleteFileModel,
):
client = await get_ragflow_client()
result = await client.delete_datasets(**(delete_params.model_dump()))
def delete_datasets_services(cls, delete_params: DeleteFileModel):
    """Delete datasets (synchronous version).

    Args:
        delete_params: Request model; every field from ``model_dump()`` is
            passed to the client as a keyword argument.

    Returns:
        Whatever ``RAGFlowClient.delete_datasets`` returns, unmodified.
    """
    ragflow = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
    extra_fields = delete_params.model_dump()
    return ragflow.delete_datasets(**extra_fields)
# 查看聊天助手列表
@classmethod
async def get_chat_assistant_list_services(
cls,
query_params: RagflowListQueryModel,
):
client = await get_ragflow_client()
result = await client.list_chat_assistants(**(query_params.model_dump()))
def get_chat_assistant_list_services(cls, query_params: RagflowListQueryModel):
    """List chat assistants (synchronous version).

    Args:
        query_params: Query model; every field from ``model_dump()`` is
            passed to the client as a keyword argument.

    Returns:
        Whatever ``RAGFlowClient.list_chat_assistants`` returns, unmodified.
    """
    ragflow = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
    filters = query_params.model_dump()
    return ragflow.list_chat_assistants(**filters)
# 修改聊天助手
@classmethod
async def update_chat_assistant_services(cls, update_params: UpdateChatAssistantModel):
client = await get_ragflow_client()
result = await client.update_chat_assistant(**(update_params.model_dump()))
def update_chat_assistant_services(cls, update_params: UpdateChatAssistantModel):
    """Update a chat assistant (synchronous version).

    Args:
        update_params: Request model; every field from ``model_dump()`` is
            passed to the client as a keyword argument.

    Returns:
        Whatever ``RAGFlowClient.update_chat_assistant`` returns, unmodified.
    """
    ragflow = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
    changes = update_params.model_dump()
    return ragflow.update_chat_assistant(**changes)
# 创建助手会话
@classmethod
async def create_session_with_chat_services(cls, create_params: CreateSessionWithChatModel):
client = await get_ragflow_client()
result = await client.create_session_with_chat(**(create_params.model_dump()))
def create_session_with_chat_services(cls, create_params: CreateSessionWithChatModel):
    """Create a session with a chat assistant (synchronous version).

    Args:
        create_params: Request model; every field from ``model_dump()`` is
            passed to the client as a keyword argument.

    Returns:
        Whatever ``RAGFlowClient.create_session_with_chat`` returns, unmodified.
    """
    ragflow = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
    session_fields = create_params.model_dump()
    return ragflow.create_session_with_chat(**session_fields)
# 与助手聊天
# 与助手聊天 - 核心方法
@classmethod
async def converse_with_chat_assistant_services(cls, converse_params: ConverseWithChatAssistantModel):
client = await get_ragflow_client()
# 修复直接返回AsyncGenerator不使用await消费流式数据
return client.converse_with_chat_assistant(**(converse_params.model_dump()))
def converse_with_chat_assistant_services(cls, converse_params: ConverseWithChatAssistantModel):
    """Converse with a chat assistant (synchronous version).

    Only ``chat_id``, ``question``, ``stream`` and ``session_id`` are read
    from ``converse_params``; each is forwarded under the same keyword.

    Returns:
        The object produced by ``RAGFlowClient.converse_with_chat_assistant``,
        handed back without being consumed. Per the original note this is a
        generator meant for streaming responses.
    """
    ragflow = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
    call_kwargs = {
        "chat_id": converse_params.chat_id,
        "question": converse_params.question,
        "stream": converse_params.stream,
        "session_id": converse_params.session_id,
    }
    # Return the client's result as-is so the caller decides how to consume it.
    return ragflow.converse_with_chat_assistant(**call_kwargs)
# 与助手聊天 (OpenAI Compatible)
# 与助手聊天 (OpenAI Compatible) - 同步版本
@classmethod
async def converse_with_chat_assistant_services_openai(cls, converse_params: ConverseWithChatAssistantModel):
client = await get_ragflow_client()
def converse_with_chat_assistant_services_openai(cls, converse_params: ConverseWithChatAssistantModel):
"""OpenAI-compatible format - synchronous version.

Wraps ``converse_params.question`` in a single user message and forwards it,
together with ``chat_id`` and ``stream``, to ``create_chat_completion``.
The client's return value is handed back unchanged.
"""
client = RAGFlowClient(RAGFlowConfig.RAGFLOW_BASE_URL, RAGFlowConfig.RAGFLOW_API_KEY)
# Construct messages list for OpenAI format
messages = [{"role": "user", "content": converse_params.question}]
# Uses defaults for model name as per user indication "server will parse this automatically"
# NOTE(review): the next two lines open the same call twice, once with
# ``await`` and once without. This looks like the removed and the added side
# of the change shown together; ``await`` is not valid in this non-async
# ``def``, so presumably only the second line is live code. Confirm against
# the real file.
return await client.create_chat_completion(
return client.create_chat_completion(
chat_id=converse_params.chat_id,
# The model name is a fixed placeholder string.
model="ragflow",
messages=messages,
stream=converse_params.stream
# NOTE(review): two closing parentheses follow, matching the two openers above.
)
)

View File

@ -189,8 +189,8 @@ class SearchService:
except SearchServiceError as exc:
logger.warning('搜索服务失败降级到RAGFlow: %s', exc)
# Fallback to normal flow
# 修复:await async方法获取AsyncGenerator
result = await RAGFlowService.converse_with_chat_assistant_services(converse_params)
# 修复:直接调用同步方法,不再使用await
result = RAGFlowService.converse_with_chat_assistant_services(converse_params)
if converse_params.stream:
async def stream_response():

View File

@ -0,0 +1,193 @@
#!/usr/bin/env python3
"""
简化后的RAGFlow实现测试脚本
验证同步Generator在异步环境中的工作
"""
import asyncio
import time
from typing import Generator, Any
# 模拟同步RAGFlow客户端
class MockSyncRAGFlowClient:
    """Stand-in for a synchronous RAGFlow client that streams canned chunks.

    Args:
        base_url: Stored on the instance; never contacted.
        api_key: Stored on the instance; never used.
        chunk_delay: Seconds slept before each chunk to imitate network
            latency. Defaults to 0.5 (the previously hard-coded value);
            pass 0 to stream without waiting.
    """

    def __init__(self, base_url: str, api_key: str, chunk_delay: float = 0.5):
        self.base_url = base_url
        self.api_key = api_key
        self.chunk_delay = chunk_delay

    def converse_with_chat_assistant(self, **kwargs) -> Generator[dict, None, None]:
        """Yield five cumulative answer chunks, pausing before each one.

        Keyword arguments are accepted so the call site can mirror the real
        client, but they are ignored. This is a generator function: nothing
        runs (not even the first progress print) until it is iterated.

        Yields:
            Dicts shaped ``{'data': {'answer': <text so far>}}``.
        """
        print("Mock client: 开始生成流式数据...")
        # Each chunk repeats the answer so far plus a little more text.
        responses = [
            {'data': {'answer': 'Hello'}},
            {'data': {'answer': 'Hello, I am'}},
            {'data': {'answer': 'Hello, I am an AI'}},
            {'data': {'answer': 'Hello, I am an AI assistant'}},
            {'data': {'answer': 'Hello, I am an AI assistant.'}},
        ]
        for position, response in enumerate(responses, start=1):
            print(f"Mock client: 生成第{position}个数据块")
            time.sleep(self.chunk_delay)  # simulated network latency (blocking)
            yield response
        print("Mock client: 流式数据生成完成")
# 模拟同步RAGFlowService
class MockRAGFlowService:
    """Stand-in for the synchronous RAGFlowService used by this script."""

    @staticmethod
    def converse_with_chat_assistant_services(converse_params) -> Generator[dict, None, None]:
        """Build a mock client and return its conversation generator.

        Args:
            converse_params: Mapping read with ``.get`` for ``chat_id``,
                ``question`` and ``session_id``. ``stream`` is always sent
                as ``True`` regardless of the mapping.

        Returns:
            The generator from ``MockSyncRAGFlowClient``, not yet iterated.
        """
        print("MockService: 调用converse_with_chat_assistant_services")
        mock_client = MockSyncRAGFlowClient("http://localhost:9099", "test_key")
        call_kwargs = {
            'chat_id': converse_params.get('chat_id'),
            'question': converse_params.get('question'),
            'stream': True,
            'session_id': converse_params.get('session_id'),
        }
        return mock_client.converse_with_chat_assistant(**call_kwargs)
# 模拟异步控制器
async def async_controller_test():
    """Drain the mock service's synchronous generator from inside a coroutine.

    Prints progress, the latency until the first chunk, the total time and a
    pass/fail verdict (pass means exactly five chunks arrived). Any exception
    is reported with a traceback and not propagated to the caller.
    """
    # Request parameters handed to the mock service.
    params = {
        'chat_id': 'test_chat_123',
        'question': '你好,请介绍一下自己',
        'stream': True,
        'session_id': 'session_456'
    }
    banner = "=" * 60
    print(banner)
    print("测试异步控制器消费同步Generator")
    print(banner)

    start_time = time.perf_counter()
    try:
        # Step 1: call the (synchronous) service layer.
        print("1. 调用RAGFlowService.converse_with_chat_assistant_services...")
        result = MockRAGFlowService.converse_with_chat_assistant_services(params)
        print(f" 返回类型: {type(result)}")
        if not isinstance(result, Generator):
            raise TypeError(f"期望Generator类型但得到 {type(result)}")

        # Step 2: consume the synchronous generator within the async context.
        print("2. 开始消费同步Generator...")
        chunk_count = 0
        try:
            for chunk_count, chunk in enumerate(result, start=1):
                print(f" 接收到第{chunk_count}个数据块: {chunk}")
                if chunk_count == 1:
                    # Only the first chunk reports time-to-first-token.
                    latency = time.perf_counter() - start_time
                    print(f" 首Token延迟: {latency:.3f}s")
                # Stand-in for per-chunk work; yields control to the loop.
                await asyncio.sleep(0.01)
        except Exception as e:
            print(f" 消费数据时出错: {e}")
            raise

        total_time = time.perf_counter() - start_time
        print(f"3. 流式处理完成,总耗时: {total_time:.3f}s")
        print(f" 总共接收数据块: {chunk_count}")

        # Step 3: verdict.
        print(
            "✅ 测试通过成功接收到所有5个数据块"
            if chunk_count == 5
            else f"❌ 测试失败期望5个数据块实际收到{chunk_count}"
        )
    except Exception as e:
        print(f"❌ 测试失败:{e}")
        import traceback
        traceback.print_exc()
# 测试同步消费 vs 异步消费
def sync_vs_async_test(item_count: int = 5, delay: float = 0.1):
    """Time draining a blocking generator in a plain loop vs. inside a coroutine.

    The async half needs an event loop of its own. ``asyncio.run()`` refuses
    to start when a loop is already running in the current thread, so when
    this function is called from inside a coroutine the fresh loop is started
    in a short-lived worker thread instead. Previously that case always
    ended in the "异步消费失败" branch and left a never-awaited coroutine.

    Args:
        item_count: Number of items each generator yields (default 5).
        delay: Seconds of blocking sleep before each item (default 0.1).

    Returns:
        None. Results are printed. A failure of the async half is printed
        and swallowed so the rest of the script keeps going.
    """
    import concurrent.futures  # local import: only the nested-loop path needs it

    print("\n" + "=" * 60)
    print("测试:同步消费 vs 异步消费")
    print("=" * 60)

    def sync_generator():
        """Yield labelled items, blocking ``delay`` seconds before each."""
        for i in range(item_count):
            time.sleep(delay)
            yield f"数据块 {i+1}"

    # 1. Plain synchronous consumption.
    print("1. 同步消费测试:")
    start_time = time.perf_counter()
    for item in sync_generator():
        print(f" {item}")
    sync_time = time.perf_counter() - start_time
    print(f" 同步消费耗时: {sync_time:.3f}s")

    # 2. Consumption from within a coroutine.
    print("\n2. 异步消费测试:")
    generator2 = sync_generator()
    start_time = time.perf_counter()

    async def async_consumer(gen):
        """Iterate ``gen`` and yield to the loop after every item."""
        count = 0
        for item in gen:
            count += 1
            print(f" {item}")
            await asyncio.sleep(0.01)  # give control back to the loop
        return count

    def run_in_fresh_loop():
        # The coroutine object is created in the thread that runs it.
        return asyncio.run(async_consumer(generator2))

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_already_running = False
    else:
        loop_already_running = True

    try:
        if loop_already_running:
            # A second loop cannot start in this thread; use a worker thread.
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                count = pool.submit(run_in_fresh_loop).result()
        else:
            count = run_in_fresh_loop()
        async_time = time.perf_counter() - start_time
        print(f" 异步消费耗时: {async_time:.3f}s")
        print(f" 处理了{count}个项目")
    except Exception as e:
        # Best-effort demo script: report the failure and carry on.
        print(f" 异步消费失败: {e}")
async def main():
    """Run the controller test, then the sync-vs-async comparison, then print a summary.

    NOTE(review): ``sync_vs_async_test()`` is called while this coroutine's
    event loop is running, and that function calls ``asyncio.run()`` itself.
    That only succeeds if the callee copes with an already-running loop.
    """
    divider = "=" * 60
    print("RAGFlow简化实现测试")
    print(divider)

    # Main scenario: an async caller draining a sync generator.
    await async_controller_test()

    # Side-by-side comparison.
    sync_vs_async_test()

    summary_lines = (
        "\n" + divider,
        "测试总结:",
        "1. 同步Generator可以在异步环境中正常工作",
        "2. 使用for循环可以自动处理同步Generator",
        "3. 异步消费需要适当让出控制权(await)",
        "4. 简化后的架构避免了复杂的async/await链",
        divider,
    )
    for line in summary_lines:
        print(line)


if __name__ == "__main__":
    asyncio.run(main())