kangda-robot-backend/ruoyi-fastapi-backend/module_admin/controller/ragflow_controller.py

262 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
简化的RAGFlow控制器 - 移除异步复杂性
"""
import asyncio
import hashlib
import json
import time
from typing import Any, Dict, Generator

from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse

from module_admin.service.login_service import LoginService
from module_admin.service.ragflow_service import RAGFlowService
from module_admin.service.search_service import SearchService
from module_admin.entity.vo.ragflow_vo import (
    ConverseWithChatAssistantModel,
    CreateSessionWithChatModel,
    UpdateChatAssistantModel,
    RagflowListQueryModel,
)
from utils.log_util import logger
from utils.response_util import ResponseUtil
# Plain APIRouter for the RAGFlow endpoints; the router-level dependency runs
# LoginService.get_current_user for every route registered on it.
ragflowController = APIRouter(
    prefix="/system/ragflow",
    dependencies=[Depends(LoginService.get_current_user)],
)
def format_sse(data: dict, event: str | None = None) -> str:
"""格式化SSE数据"""
payload = json.dumps(data, ensure_ascii=False)
prefix = f'event: {event}\n' if event else ''
return f'{prefix}data: {payload}\n\n'
def parse_result(result: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a raw RAGFlow service result into the project's response envelope.

    Args:
        result: Mapping returned by a ``RAGFlowService`` call. A missing
            ``code`` is treated as ``0`` (success).

    Returns:
        ``ResponseUtil.success(data=...)`` when ``code`` is ``0``; otherwise
        ``ResponseUtil.error(...)`` carrying the upstream ``message``/``msg``
        (or a generic fallback) together with the upstream ``data``.
    """
    if not isinstance(result, dict):
        # A None / non-mapping result previously raised AttributeError on
        # ``.get``; report it as a generic interface error instead.
        return ResponseUtil.error(msg='接口异常')

    data = result.get('data')
    if result.get('code', 0) != 0:
        msg = result.get('message') or result.get('msg') or '接口异常'
        return ResponseUtil.error(msg=msg, data=data)
    return ResponseUtil.success(data=data)
def build_chat_cache_key(chat_id: str, question: str) -> str:
    """Return the cache key for a (chat, question) pair.

    The question is reduced to the hex SHA-256 of its UTF-8 bytes, so the key
    has a fixed-length suffix regardless of how long the question is.

    Args:
        chat_id: Identifier of the chat assistant.
        question: The user's question text.

    Returns:
        A key of the form ``ragflow:chat:<chat_id>:<sha256 hex digest>``.
    """
    hasher = hashlib.sha256()
    hasher.update(question.encode('utf-8'))
    return 'ragflow:chat:{}:{}'.format(chat_id, hasher.hexdigest())
# Chat endpoint: serves both streaming (SSE) and non-streaming conversations
@ragflowController.post('/converse_with_chat_assistant')
async def converse_with_chat_assistant(
    request: Request,
    converse_params: ConverseWithChatAssistantModel,  # validated request body
):
    """
    Converse with a chat assistant.

    Flow:
      1. Ask ``SearchService`` to classify the question; a ``'SEARCH'`` intent
         is delegated to ``SearchService.handle_search_chat``. Any failure in
         this step is logged and the request falls through to RAGFlow.
      2. For non-streaming requests with Redis available, return a cached
         answer when one exists.
      3. Call the synchronous RAGFlow service. Streaming requests are returned
         as a ``text/event-stream`` response; non-streaming results are wrapped
         by ``parse_result`` and, on success, cached for 60 seconds.

    Args:
        request: Incoming request; ``request.app.state.redis`` is used as the
            cache client when present.
        converse_params: Conversation parameters (``chat_id``, ``question``,
            ``stream`` are read here).

    Returns:
        A ``StreamingResponse`` for streaming requests, otherwise the response
        built by ``ResponseUtil``. Service exceptions are caught and reported
        as an error response rather than propagated.
    """
    start_time = time.perf_counter()
    # Redis is optional: caching is skipped when no client is attached to the app.
    redis = getattr(request.app.state, 'redis', None)
    cache_key = None

    # Intent routing is best-effort: on any failure fall back to the RAG service.
    try:
        intent = await SearchService.classify_intent(converse_params.question)
        logger.info(f"[RAG_PROFILE] Intent Router ({converse_params.chat_id}): {intent}")
        if intent == 'SEARCH':
            return await SearchService.handle_search_chat(converse_params, redis)
    except Exception as e:
        logger.warning(f"意图分类失败使用RAG服务: {e}")

    # Cache lookup (non-streaming only).
    # NOTE(review): the key covers only chat_id + question, so an answer is
    # shared across sessions of the same chat for the TTL — confirm intended.
    if not converse_params.stream and redis:
        cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
        try:
            cached = await redis.get(cache_key)
            if cached:
                # f-string for consistency with every other log call in this module.
                logger.info(f'ragflow对话命中缓存: chat={converse_params.chat_id}')
                # Pass the payload as ``data=`` like parse_result does, so a cache
                # hit has the same envelope shape as a fresh answer.
                return ResponseUtil.success(data=json.loads(cached))
        except Exception as e:
            logger.warning(f"缓存获取失败: {e}")

    try:
        # The RAGFlow service is synchronous; run it in a worker thread so it
        # does not block the event loop for the duration of the upstream call.
        result = await asyncio.to_thread(
            RAGFlowService.converse_with_chat_assistant_services, converse_params
        )

        # Streaming response
        if converse_params.stream:
            return StreamingResponse(
                stream_ragflow_response(result, converse_params.chat_id, start_time),
                media_type='text/event-stream',
                headers={
                    'Cache-Control': 'no-cache',
                    'Connection': 'keep-alive',
                    'X-Accel-Buffering': 'no',
                    'Transfer-Encoding': 'chunked'
                }
            )

        # Non-streaming response
        response = parse_result(result)

        # Cache successful answers only; a cache write failure must not fail the request.
        if redis and cache_key and isinstance(result, dict) and result.get('code') == 0:
            try:
                await redis.set(cache_key, json.dumps(result.get('data', {}), ensure_ascii=False), ex=60)
            except Exception as e:
                logger.warning(f"缓存设置失败: {e}")

        logger.info(f'[RAG_PROFILE] Total Sync Duration ({converse_params.chat_id}): {time.perf_counter() - start_time:.3f}s')
        return response
    except Exception as e:
        logger.exception(f'[RAG_PROFILE] ragflow对话异常: {e}')
        return ResponseUtil.error(msg=f'对话服务异常: {str(e)}')
def stream_ragflow_response(result: Generator, chat_id: str, start_time: float) -> Generator[str, None, None]:
    """
    Re-emit RAGFlow stream chunks as Server-Sent Events frames.

    A ``ping`` event is sent before the first upstream chunk is read so the
    client sees the connection open immediately. For chunks whose body carries
    a string ``answer``, only the text added since the previous chunk is
    forwarded; if the new answer does not extend the previous one, it is
    forwarded whole and becomes the new baseline. Upstream error chunks become
    ``error`` events, and a final ``end`` event marks normal completion.

    Args:
        result: Iterable of upstream chunks; each is either a dict (optionally
            with ``code`` / ``message`` / ``data``) or a bare payload.
        chat_id: Chat identifier, used for logging only.
        start_time: ``time.perf_counter()`` value captured when the request
            started; used to log the total duration.

    Yields:
        SSE-formatted strings.
    """
    server_stream_start = time.time()
    logger.info(f"[RAG_SERVER {server_stream_start:.3f}] 🚀 开始流式响应处理chat_id: {chat_id}")

    last_answer = ""
    first_token_received = False
    start_stream_time = time.perf_counter()

    # Send a connection-established frame right away, before waiting on upstream.
    connection_time = time.time()
    yield "event: ping\ndata: {\"status\": \"connected\"}\n\n"
    logger.info(f"[RAG_SERVER {connection_time:.3f}] 📡 连接建立ping发送耗时: {connection_time - server_stream_start:.3f}s")

    chunk_count = 0
    try:
        for chunk in result:
            chunk_time = time.time()
            chunk_count += 1

            # Log the latency of the first upstream chunk.
            if not first_token_received:
                first_token_received = True
                latency = time.perf_counter() - start_stream_time
                logger.info(f"[RAG_SERVER {chunk_time:.3f}] 🎯 首token到达耗时: {latency:.3f}s")

            # Check for upstream errors BEFORE the empty-payload skip: an error
            # chunk with no ``data`` would otherwise be dropped silently and
            # the client would never learn about the failure.
            if isinstance(chunk, dict) and chunk.get('code'):
                logger.error(f"[RAG_SERVER {chunk_time:.3f}] ❌ RAGFlow Stream Error: {chunk}")
                yield format_sse({'message': chunk.get('message') or '流式处理异常'}, event='error')
                continue

            payload = chunk.get('data') if isinstance(chunk, dict) else chunk
            if not payload:
                logger.info(f"[RAG_SERVER {chunk_time:.3f}] ⚠️ 空chunk跳过")
                continue

            # Non-dict payloads are wrapped so every frame carries a JSON object.
            body = payload if isinstance(payload, dict) else {'data': payload}

            answer = body.get('answer')
            if not isinstance(answer, str):
                # No answer text (or a non-string one): forward unchanged rather
                # than failing the whole stream on ``startswith``.
                logger.info(f"[RAG_SERVER {chunk_time:.3f}] 📦 Chunk #{chunk_count} 其他数据: {body}")
                yield format_sse(body)
                continue

            if answer.startswith(last_answer):
                # Cumulative answer: forward only the newly added suffix.
                delta = answer[len(last_answer):]
                if delta:
                    last_answer = answer
                    logger.info(f"[RAG_SERVER {chunk_time:.3f}] 📝 Chunk #{chunk_count} 处理完成delta长度: {len(delta)}")
                    # Build a new dict so the upstream chunk is not mutated.
                    yield format_sse({**body, 'answer': delta})
            else:
                # Answer no longer extends the previous one: treat as a reset.
                last_answer = answer
                logger.info(f"[RAG_SERVER {chunk_time:.3f}] 🔄 Chunk #{chunk_count} 上下文重置")
                yield format_sse(body)

        # Normal end of stream.
        stream_end_time = time.time()
        yield format_sse({'status': 'completed'}, event='end')
        logger.info(f"[RAG_SERVER {stream_end_time:.3f}] 🏁 流式响应完成chat_id: {chat_id}")
        logger.info(f"[RAG_SERVER {stream_end_time:.3f}] 📊 总共处理chunk数量: {chunk_count}")
    except Exception as exc:
        error_time = time.time()
        logger.exception(f'[RAG_SERVER {error_time:.3f}] ragflow流式对话异常: {exc}')
        yield format_sse({'message': str(exc)}, event='error')
    finally:
        total_time = time.perf_counter() - start_time
        logger.info(f'[RAG_SERVER {time.time():.3f}] ⏱️ Total Stream Duration ({chat_id}): {total_time:.3f}s')
# Chat assistant list
@ragflowController.post('/get_chat_assistant_list')
async def get_chat_assistant_list(
    query_params: RagflowListQueryModel,  # validated request body
):
    """Return the list of chat assistants.

    Args:
        query_params: List query parameters forwarded to the service.

    Returns:
        The service result wrapped by ``parse_result``.
    """
    # The service call is synchronous; run it in a worker thread so the
    # event loop is not blocked while waiting on RAGFlow.
    result = await asyncio.to_thread(RAGFlowService.get_chat_assistant_list_services, query_params)
    return parse_result(result)
# Create session
@ragflowController.post('/create_session_with_chat')
async def create_session_with_chat(
    create_params: CreateSessionWithChatModel,  # validated request body
):
    """Create a chat session.

    Args:
        create_params: Session creation parameters forwarded to the service.

    Returns:
        The service result wrapped by ``parse_result``.
    """
    # The service call is synchronous; run it in a worker thread so the
    # event loop is not blocked while waiting on RAGFlow.
    result = await asyncio.to_thread(RAGFlowService.create_session_with_chat_services, create_params)
    return parse_result(result)
# Update chat assistant
@ragflowController.post('/update_chat_assistant')
async def update_chat_assistant(
    update_params: UpdateChatAssistantModel,  # validated request body
):
    """Update a chat assistant.

    Args:
        update_params: Update parameters forwarded to the service.

    Returns:
        The service result wrapped by ``parse_result``.
    """
    # The service call is synchronous; run it in a worker thread so the
    # event loop is not blocked while waiting on RAGFlow.
    result = await asyncio.to_thread(RAGFlowService.update_chat_assistant_services, update_params)
    return parse_result(result)
# Dataset endpoints
@ragflowController.post('/dataset_list')
async def get_system_ragflow_list(
    request: Request,
    ragflow_list_query: Any,  # Any is used to avoid importing the concrete model class
):
    """Return the dataset list.

    Args:
        request: Incoming request (not used by this handler).
        ragflow_list_query: Query parameters forwarded to the service as-is.
            NOTE(review): annotated as ``Any``, so no request model validates
            it here — confirm how clients are expected to send this value.

    Returns:
        The service result wrapped by ``parse_result``.
    """
    # The service call is synchronous; run it in a worker thread so the
    # event loop is not blocked. ``None`` is passed as the service's first
    # argument, exactly as before.
    result = await asyncio.to_thread(
        RAGFlowService.get_ragflow_dataset_list_services, None, ragflow_list_query
    )
    return parse_result(result)
@ragflowController.post('/create_dataset')
async def create_dataset(
    create_dataset_params: Any,  # Any is used to avoid importing the concrete model class
):
    """Create a dataset.

    Args:
        create_dataset_params: Creation parameters forwarded to the service
            as-is. NOTE(review): annotated as ``Any``, so no request model
            validates it here — confirm how clients are expected to send it.

    Returns:
        The service result wrapped by ``parse_result``.
    """
    # The service call is synchronous; run it in a worker thread so the
    # event loop is not blocked while waiting on RAGFlow.
    result = await asyncio.to_thread(RAGFlowService.create_dataset_services, create_dataset_params)
    return parse_result(result)
@ragflowController.post('/delete_datasets')
async def delete_datasets(
    delete_params: Any,  # Any is used to avoid importing the concrete model class
):
    """Delete datasets.

    Args:
        delete_params: Deletion parameters forwarded to the service as-is.
            NOTE(review): annotated as ``Any``, so no request model validates
            it here — confirm how clients are expected to send it.

    Returns:
        The service result wrapped by ``parse_result``.
    """
    # The service call is synchronous; run it in a worker thread so the
    # event loop is not blocked while waiting on RAGFlow.
    result = await asyncio.to_thread(RAGFlowService.delete_datasets_services, delete_params)
    return parse_result(result)