239 lines
8.5 KiB
Python
239 lines
8.5 KiB
Python
"""
|
||
Simplified RAGFlow controller - async complexity removed.
|
||
"""
|
||
|
||
import asyncio
import hashlib
import json
import time
from typing import Any, Dict, Generator

from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse

from module_admin.entity.vo.ragflow_vo import (
    ConverseWithChatAssistantModel,
    CreateSessionWithChatModel,
    RagflowListQueryModel,
    UpdateChatAssistantModel,
)
from module_admin.service.login_service import LoginService
from module_admin.service.ragflow_service import RAGFlowService
from module_admin.service.search_service import SearchService
from utils.log_util import logger
from utils.response_util import ResponseUtil
|
||
|
||
# Plain FastAPI router; the endpoints registered on it are served under /system/ragflow.
# NOTE(review): no router-level dependencies are declared here although Depends and
# LoginService are imported by this module — confirm that access control for these
# endpoints is enforced somewhere else.
ragflowController = APIRouter(prefix="/system/ragflow")
|
||
|
||
|
||
def format_sse(data: dict, event: str | None = None) -> str:
|
||
"""格式化SSE数据"""
|
||
payload = json.dumps(data, ensure_ascii=False)
|
||
prefix = f'event: {event}\n' if event else ''
|
||
return f'{prefix}data: {payload}\n\n'
|
||
|
||
|
||
def parse_result(result: Dict[str, Any]) -> Dict[str, Any]:
    """Turn a raw service result into the project's standard API response.

    Args:
        result: Mapping returned by a service call. ``code`` (default 0)
            selects success or failure, ``data`` is forwarded as the
            payload, and ``message`` / ``msg`` supply the failure text.

    Returns:
        ``ResponseUtil.success`` carrying ``data`` when ``code`` is 0,
        otherwise ``ResponseUtil.error`` carrying the first non-empty of
        ``message``, ``msg`` or a generic fallback text.
    """
    payload = result.get('data', None)
    if result.get('code', 0) == 0:
        return ResponseUtil.success(data=payload)

    # Upstream results are not consistent about which key holds the error text.
    reason = result.get('message') or result.get('msg') or '接口异常'
    return ResponseUtil.error(msg=reason, data=payload)
|
||
|
||
|
||
def build_chat_cache_key(chat_id: str, question: str) -> str:
    """Build the cache key for one question asked of one chat.

    The question is reduced to its SHA-256 hex digest so the key has a
    fixed length regardless of how long the question is.

    Args:
        chat_id: Identifier of the chat the question belongs to.
        question: Question text; hashed as UTF-8.

    Returns:
        A key of the form ``ragflow:chat:<chat_id>:<sha256 hex digest>``.
    """
    hasher = hashlib.sha256()
    hasher.update(question.encode('utf-8'))
    return 'ragflow:chat:{}:{}'.format(chat_id, hasher.hexdigest())
|
||
|
||
|
||
# Chat endpoint: serves both the non-streaming (cacheable) and the SSE streaming path
@ragflowController.post('/converse_with_chat_assistant')
async def converse_with_chat_assistant(
    request: Request,
    converse_params: ConverseWithChatAssistantModel,  # validated request body
):
    """
    Converse with a RAGFlow chat assistant.

    Processing order:
      1. Classify the question's intent; a ``SEARCH`` intent is handed to
         ``SearchService``. Any failure in this step is logged and the
         request falls through to the RAG path.
      2. For non-streaming requests with Redis available, return the cached
         answer for this chat/question pair when one exists.
      3. Call the synchronous RAGFlow service in a worker thread so the
         event loop is not blocked while the upstream call runs.
      4. Streaming requests get an SSE ``StreamingResponse``; non-streaming
         requests get the parsed result, and a successful answer is cached
         for 60 seconds.

    Args:
        request: Incoming request; only ``request.app.state.redis`` is read.
        converse_params: Body model; ``question``, ``chat_id`` and ``stream``
            are read here and the whole model is forwarded to the services.

    Returns:
        A ``StreamingResponse`` for streaming requests, otherwise a standard
        API response. Failures of the RAG call are reported as an error
        response rather than raised.
    """
    start_time = time.perf_counter()

    # Redis is optional: when the app has none configured, caching is skipped.
    redis = getattr(request.app.state, 'redis', None)
    cache_key = None

    # Route to the search service when the classifier says so.
    try:
        intent = await SearchService.classify_intent(converse_params.question)
        logger.info(f"[RAG_PROFILE] Intent Router ({converse_params.chat_id}): {intent}")

        if intent == 'SEARCH':
            return await SearchService.handle_search_chat(converse_params, redis)
    except Exception as e:
        logger.warning(f"意图分类失败,使用RAG服务: {e}")

    # Cache lookup (non-streaming only); a cache failure must never fail the request.
    if not converse_params.stream and redis:
        cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
        try:
            cached = await redis.get(cache_key)
            if cached:
                # f-string keeps this consistent with every other log call in the module.
                logger.info(f'ragflow对话命中缓存: chat={converse_params.chat_id}')
                # Pass the payload as ``data`` so a cache hit has the same
                # response shape as the uncached path built by parse_result.
                return ResponseUtil.success(data=json.loads(cached))
        except Exception as e:
            logger.warning(f"缓存获取失败: {e}")

    try:
        # The RAGFlow service is synchronous; running it inline would stall
        # every other request served by this event loop.
        result = await asyncio.to_thread(
            RAGFlowService.converse_with_chat_assistant_services, converse_params
        )

        # Streaming response
        if converse_params.stream:
            return StreamingResponse(
                stream_ragflow_response(result, converse_params.chat_id, start_time),
                media_type='text/event-stream',
                headers={
                    'Cache-Control': 'no-cache',
                    'Connection': 'keep-alive',
                    # Asks nginx-style proxies not to buffer the event stream.
                    'X-Accel-Buffering': 'no',
                    'Transfer-Encoding': 'chunked',
                },
            )

        # Non-streaming response
        response = parse_result(result)

        # Only successful answers are cached, and only when a key was built above.
        if redis and cache_key and isinstance(result, dict) and result.get('code') == 0:
            try:
                await redis.set(cache_key, json.dumps(result.get('data', {}), ensure_ascii=False), ex=60)
            except Exception as e:
                logger.warning(f"缓存设置失败: {e}")

        logger.info(f'[RAG_PROFILE] Total Sync Duration ({converse_params.chat_id}): {time.perf_counter() - start_time:.3f}s')
        return response

    except Exception as e:
        logger.exception(f'[RAG_PROFILE] ragflow对话异常: {e}')
        return ResponseUtil.error(msg=f'对话服务异常: {str(e)}')
|
||
|
||
|
||
def stream_ragflow_response(result: Generator, chat_id: str, start_time: float) -> Generator[str, None, None]:
    """
    Re-emit RAGFlow chunks as Server-Sent Events frames.

    A ``ping`` frame is sent before the first upstream chunk is read. When a
    chunk's ``answer`` extends the previously seen answer, only the newly
    added suffix is forwarded; otherwise the chunk is forwarded as received.
    Chunks with a non-zero ``code`` become ``error`` frames. An ``end`` frame
    closes a stream that was consumed without an exception; an exception
    produces an ``error`` frame instead.

    Args:
        result: Iterable of upstream chunks. A single dict is also accepted
            and treated as one chunk.
        chat_id: Chat identifier, used only in log messages.
        start_time: ``time.perf_counter()`` value taken when the request
            started, used for the total-duration log line.

    Yields:
        SSE frame strings.
    """
    last_answer = ""
    first_token_received = False
    start_stream_time = time.perf_counter()

    # Send something immediately so the client sees the connection is open
    # before the first upstream chunk arrives.
    yield format_sse({'status': 'connected'}, event='ping')

    # A dict means the service returned one eager result instead of a chunk
    # iterator; iterating it directly would walk its keys.
    chunks = [result] if isinstance(result, dict) else result

    try:
        for chunk in chunks:
            # Log the latency of the first upstream chunk once.
            if not first_token_received:
                first_token_received = True
                latency = time.perf_counter() - start_stream_time
                logger.info(f"[RAG_PROFILE] Time to First Token ({chat_id}): {latency:.3f}s")

            is_mapping = isinstance(chunk, dict)

            # Errors are checked before the empty-payload skip: an error chunk
            # may carry no data and would otherwise be dropped silently.
            if is_mapping and chunk.get('code') and chunk.get('code') != 0:
                logger.error(f"RAGFlow Stream Error: {chunk}")
                message = chunk.get('message') or chunk.get('msg') or '流式处理异常'
                yield format_sse({'message': message}, event='error')
                continue

            payload = chunk.get('data') if is_mapping else chunk
            if not payload:
                continue

            # Non-dict payloads are wrapped so every frame body is a JSON object.
            if not isinstance(payload, dict):
                yield format_sse({'data': payload})
                continue

            current_answer = payload.get('answer')
            if not isinstance(current_answer, str):
                # No usable answer text to diff; forward the chunk unchanged.
                yield format_sse(payload)
                continue

            if current_answer.startswith(last_answer):
                delta = current_answer[len(last_answer):]
                last_answer = current_answer
                if delta:
                    # Build a new body instead of mutating the upstream chunk.
                    yield format_sse({**payload, 'answer': delta})
            else:
                # The answer no longer extends the previous one (context
                # reset): forward the full text and restart the diff from it.
                last_answer = current_answer
                yield format_sse(payload)

        # End of stream
        yield format_sse({'status': 'completed'}, event='end')

    except Exception as exc:
        logger.exception(f'[RAG_PROFILE] ragflow流式对话异常: {exc}')
        yield format_sse({'message': str(exc)}, event='error')
    finally:
        total_time = time.perf_counter() - start_time
        logger.info(f'[RAG_PROFILE] Total Stream Duration ({chat_id}): {total_time:.3f}s')
|
||
|
||
|
||
# Chat assistant list
@ragflowController.post('/get_chat_assistant_list')
async def get_chat_assistant_list(
    query_params: RagflowListQueryModel,  # validated request body
):
    """Return the list of chat assistants.

    The RAGFlow service call is synchronous, so it runs in a worker thread
    to keep the event loop free for other requests.

    Args:
        query_params: List query, forwarded unchanged to the service.

    Returns:
        The standard API response built by ``parse_result``.
    """
    result = await asyncio.to_thread(RAGFlowService.get_chat_assistant_list_services, query_params)
    return parse_result(result)
|
||
|
||
|
||
# Create session
@ragflowController.post('/create_session_with_chat')
async def create_session_with_chat(
    create_params: CreateSessionWithChatModel,  # validated request body
):
    """Create a chat session.

    The RAGFlow service call is synchronous, so it runs in a worker thread
    to keep the event loop free for other requests.

    Args:
        create_params: Session creation parameters, forwarded unchanged to
            the service.

    Returns:
        The standard API response built by ``parse_result``.
    """
    result = await asyncio.to_thread(RAGFlowService.create_session_with_chat_services, create_params)
    return parse_result(result)
|
||
|
||
|
||
# Update chat assistant
@ragflowController.post('/update_chat_assistant')
async def update_chat_assistant(
    update_params: UpdateChatAssistantModel,  # validated request body
):
    """Update a chat assistant.

    The RAGFlow service call is synchronous, so it runs in a worker thread
    to keep the event loop free for other requests.

    Args:
        update_params: Update parameters, forwarded unchanged to the service.

    Returns:
        The standard API response built by ``parse_result``.
    """
    result = await asyncio.to_thread(RAGFlowService.update_chat_assistant_services, update_params)
    return parse_result(result)
|
||
|
||
|
||
# Dataset endpoints
@ragflowController.post('/dataset_list')
async def get_system_ragflow_list(
    request: Request,
    # NOTE(review): annotated as Any to avoid importing a concrete model, so
    # FastAPI has no schema to validate against and will likely not read this
    # from the JSON body — confirm the intended request model.
    ragflow_list_query: Any,
):
    """Return the list of datasets.

    The RAGFlow service call is synchronous, so it runs in a worker thread
    to keep the event loop free for other requests.

    Args:
        request: Incoming request; not used by this handler.
        ragflow_list_query: List query, forwarded unchanged to the service.

    Returns:
        The standard API response built by ``parse_result``.
    """
    # The service's first positional argument is passed as None here
    # (presumably a database-session slot — TODO confirm against the service).
    result = await asyncio.to_thread(
        RAGFlowService.get_ragflow_dataset_list_services, None, ragflow_list_query
    )
    return parse_result(result)
|
||
|
||
|
||
@ragflowController.post('/create_dataset')
async def create_dataset(
    # NOTE(review): annotated as Any to avoid importing a concrete model, so
    # FastAPI has no schema to validate against and will likely not read this
    # from the JSON body — confirm the intended request model.
    create_dataset_params: Any,
):
    """Create a dataset.

    The RAGFlow service call is synchronous, so it runs in a worker thread
    to keep the event loop free for other requests.

    Args:
        create_dataset_params: Creation parameters, forwarded unchanged to
            the service.

    Returns:
        The standard API response built by ``parse_result``.
    """
    result = await asyncio.to_thread(RAGFlowService.create_dataset_services, create_dataset_params)
    return parse_result(result)
|
||
|
||
|
||
@ragflowController.post('/delete_datasets')
async def delete_datasets(
    # NOTE(review): annotated as Any to avoid importing a concrete model, so
    # FastAPI has no schema to validate against and will likely not read this
    # from the JSON body — confirm the intended request model.
    delete_params: Any,
):
    """Delete datasets.

    The RAGFlow service call is synchronous, so it runs in a worker thread
    to keep the event loop free for other requests.

    Args:
        delete_params: Deletion parameters, forwarded unchanged to the
            service.

    Returns:
        The standard API response built by ``parse_result``.
    """
    result = await asyncio.to_thread(RAGFlowService.delete_datasets_services, delete_params)
    return parse_result(result)