kangda-robot-backend/ruoyi-fastapi-backend/module_admin/controller/ragflow_controller.py

313 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
简化的RAGFlow控制器 - 集成语义缓存
"""
import asyncio
import hashlib
import json
import time
from typing import Any, Dict, Generator, Optional, Tuple
from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
from module_admin.service.login_service import LoginService
from module_admin.service.ragflow_service import RAGFlowService
from module_admin.service.search_service import SearchService
from module_admin.entity.vo.ragflow_vo import (
ConverseWithChatAssistantModel,
CreateSessionWithChatModel,
UpdateChatAssistantModel,
RagflowListQueryModel
)
from utils.log_util import logger
from utils.response_util import ResponseUtil
from utils.semantic_cache_service import get_semantic_cache_service, lookup_question, store_qa_pair
async def _async_store_qa(chat_id: str, question: str, answer: str, redis) -> None:
    """
    Store one question/answer pair in the semantic cache, best-effort.

    Delegates to ``store_qa_pair``; any exception it raises is logged as a
    warning and swallowed, so this coroutine never propagates a failure.
    """
    try:
        await store_qa_pair(chat_id, question, answer, redis)
    except Exception as err:
        # Caching is optional: report the problem and carry on.
        logger.warning(f"[SemanticCache] 异步存储失败: {err}")
# Standard APIRouter for the RAGFlow endpoints; every route under this prefix
# declares LoginService.get_current_user as a dependency (authentication).
ragflowController = APIRouter(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)])
def format_sse(data: dict, event: str | None = None) -> str:
"""格式化SSE数据"""
payload = json.dumps(data, ensure_ascii=False)
prefix = f'event: {event}\n' if event else ''
return f'{prefix}data: {payload}\n\n'
def parse_result(result: Dict[str, Any]) -> Dict[str, Any]:
    """
    Turn a raw service result dict into a ResponseUtil response.

    A missing or zero ``code`` is treated as success and wraps ``data``.
    Any other code yields an error response whose message is taken from
    ``message``, then ``msg``, then a fixed fallback text.
    """
    payload = result.get('data', None)
    if result.get('code', 0) == 0:
        return ResponseUtil.success(data=payload)
    error_text = result.get('message') or result.get('msg') or '接口异常'
    return ResponseUtil.error(msg=error_text, data=payload)
def build_chat_cache_key(chat_id: str, question: str) -> str:
    """
    Build the exact-match cache key for a chat question.

    The key is ``ragflow:chat:<chat_id>:<sha256 hex digest of the UTF-8
    encoded question>``.
    """
    hasher = hashlib.sha256()
    hasher.update(question.encode('utf-8'))
    return 'ragflow:chat:{}:{}'.format(chat_id, hasher.hexdigest())
# Conversation endpoint (streaming and non-streaming).
# Strong references to fire-and-forget tasks: the event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_BACKGROUND_TASKS: set = set()


@ragflowController.post('/converse_with_chat_assistant')
async def converse_with_chat_assistant(
    request: Request,
    converse_params: ConverseWithChatAssistantModel,  # Pydantic request model
):
    """
    Converse with a chat assistant, with semantic and exact-match caching.

    Flow, in order:
      1. Non-streaming requests with a redis handle first try the semantic
         cache (``lookup_question``); a hit is returned immediately.
      2. The question is classified by ``SearchService.classify_intent``;
         a ``'SEARCH'`` intent is delegated to ``SearchService``. A failure
         here is logged and the request falls through to RAGFlow.
      3. Non-streaming requests try the exact-match redis cache.
      4. The RAGFlow service is called. Streaming requests get a
         ``StreamingResponse`` of SSE frames; non-streaming requests get the
         ``parse_result`` response, and a successful answer is written to
         both caches.

    All cache operations are best-effort: a failure is logged and the
    request continues. Any exception from the RAGFlow call itself is logged
    and turned into a ``ResponseUtil.error`` response.
    """
    start_time = time.perf_counter()
    # The redis handle is optional; without it all caching is skipped.
    redis = getattr(request.app.state, 'redis', None)
    cache_key = None
    cacheable = (not converse_params.stream) and bool(redis)

    # 1. Semantic cache lookup (non-streaming only).
    if cacheable:
        try:
            cache_result = await lookup_question(
                converse_params.chat_id,
                converse_params.question,
                redis
            )
        except Exception as e:
            # Treat a failed lookup as a miss, like the other cache paths.
            logger.warning(f"[SemanticCache] 缓存查询失败: {e}")
            cache_result = None
        if cache_result:
            cached_answer, similarity = cache_result
            logger.info(f'[SemanticCache] 命中缓存 (相似度={similarity:.2f}): chat_id={converse_params.chat_id}')
            # Wrap the cached answer in the standard response envelope.
            return ResponseUtil.success(data={'answer': cached_answer, 'from_cache': True, 'similarity': similarity})

    # 2. Intent routing: hand search-type questions to the search service.
    try:
        intent = await SearchService.classify_intent(converse_params.question)
        logger.info(f"[RAG_PROFILE] Intent Router ({converse_params.chat_id}): {intent}")
        if intent == 'SEARCH':
            return await SearchService.handle_search_chat(converse_params, redis)
    except Exception as e:
        logger.warning(f"意图分类失败使用RAG服务: {e}")

    # 3. Exact-match cache lookup (non-streaming only).
    if cacheable:
        cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
        try:
            cached = await redis.get(cache_key)
            if cached:
                logger.info(f'ragflow对话命中缓存: chat={converse_params.chat_id}')
                # Pass the payload as ``data=`` like every other call in this
                # file, so a cache hit has the same shape as the uncached
                # ``parse_result`` response.
                return ResponseUtil.success(data=json.loads(cached))
        except Exception as e:
            logger.warning(f"缓存获取失败: {e}")

    # 4. Call RAGFlow.
    try:
        # The service call is synchronous; run it in a worker thread so it
        # does not block the event loop while waiting on RAGFlow.
        result = await asyncio.to_thread(
            RAGFlowService.converse_with_chat_assistant_services,
            converse_params,
        )

        # Streaming response: relay chunks as SSE frames.
        if converse_params.stream:
            return StreamingResponse(
                stream_ragflow_response(result, converse_params.chat_id, start_time),
                media_type='text/event-stream',
                headers={
                    'Cache-Control': 'no-cache',
                    'Connection': 'keep-alive',
                    'X-Accel-Buffering': 'no',
                    'Transfer-Encoding': 'chunked'
                }
            )

        # Non-streaming response.
        response = parse_result(result)
        succeeded = isinstance(result, dict) and result.get('code') == 0

        # Store the question/answer pair in the semantic cache.
        if redis and succeeded:
            try:
                answer_data = result.get('data', {})
                answer_text = answer_data.get('answer') if isinstance(answer_data, dict) else str(answer_data)
                # Answers shorter than 10 stripped characters are not stored.
                if answer_text and len(answer_text.strip()) >= 10:
                    # Store in the background so the response is not delayed;
                    # keep a reference until the task completes.
                    store_task = asyncio.create_task(_async_store_qa(
                        converse_params.chat_id,
                        converse_params.question,
                        answer_text,
                        redis
                    ))
                    _BACKGROUND_TASKS.add(store_task)
                    store_task.add_done_callback(_BACKGROUND_TASKS.discard)
            except Exception as e:
                logger.warning(f"语义缓存存储失败: {e}")

        # Store the payload in the exact-match cache (ex=60; presumably a
        # 60-second expiry — confirm against the redis client in use).
        if redis and cache_key and succeeded:
            try:
                await redis.set(cache_key, json.dumps(result.get('data', {}), ensure_ascii=False), ex=60)
            except Exception as e:
                logger.warning(f"缓存设置失败: {e}")

        logger.info(f'[RAG_PROFILE] Total Sync Duration ({converse_params.chat_id}): {time.perf_counter() - start_time:.3f}s')
        return response
    except Exception as e:
        logger.exception(f'[RAG_PROFILE] ragflow对话异常: {e}')
        return ResponseUtil.error(msg=f'对话服务异常: {str(e)}')
def stream_ragflow_response(result: Generator, chat_id: str, start_time: float) -> Generator[str, None, None]:
    """
    Relay RAGFlow stream chunks to the client as Server-Sent Events frames.

    A ``ping`` frame is emitted before the first upstream chunk is read, so
    the client sees the connection open without waiting for the first token.
    Each chunk is then handled as follows:

    * A dict chunk with a non-zero ``code`` produces an ``error`` frame.
      This is checked before the empty-payload skip, so an error chunk that
      carries no ``data`` is still reported.
    * A chunk with an empty payload is skipped.
    * A payload with an ``answer`` is reduced to the text added since the
      previous chunk. If the new answer does not extend the previous one,
      the payload is forwarded unchanged and tracking restarts from it.
    * Any other payload is forwarded as-is.

    An ``end`` frame follows the last chunk. An exception while iterating
    is logged and reported as an ``error`` frame. The total duration since
    *start_time* (a ``time.perf_counter()`` value) is always logged.

    Args:
        result: iterable of upstream chunks (dicts or raw payloads).
        chat_id: chat identifier, used only in log lines.
        start_time: ``time.perf_counter()`` reading taken by the caller.

    Yields:
        SSE-formatted strings.
    """
    server_stream_start = time.time()
    logger.info(f"[RAG_SERVER {server_stream_start:.3f}] 🚀 开始流式响应处理chat_id: {chat_id}")
    last_answer = ""
    first_token_received = False
    start_stream_time = time.perf_counter()

    # Send a connection-established frame right away to hide first-token latency.
    connection_time = time.time()
    ping_message = "event: ping\ndata: {\"status\": \"connected\"}\n\n"
    yield ping_message
    logger.info(f"[RAG_SERVER {connection_time:.3f}] 📡 连接建立ping发送耗时: {connection_time - server_stream_start:.3f}s")

    chunk_count = 0
    try:
        for chunk in result:
            chunk_time = time.time()
            chunk_count += 1

            # Log the latency of the first upstream chunk.
            if not first_token_received:
                first_token_received = True
                latency = time.perf_counter() - start_stream_time
                logger.info(f"[RAG_SERVER {chunk_time:.3f}] 🎯 首token到达耗时: {latency:.3f}s")

            # Report upstream errors first: an error chunk may have no
            # ``data`` and must not be swallowed by the empty-payload skip.
            error_code = chunk.get('code') if isinstance(chunk, dict) else None
            if error_code and error_code != 0:
                logger.error(f"[RAG_SERVER {chunk_time:.3f}] ❌ RAGFlow Stream Error: {chunk}")
                error_message = format_sse({'message': chunk.get('message', '流式处理异常')}, event='error')
                yield error_message
                continue

            # Extract the payload; skip chunks that carry nothing.
            payload = chunk.get('data') if isinstance(chunk, dict) else chunk
            if not payload:
                logger.info(f"[RAG_SERVER {chunk_time:.3f}] ⚠️ 空chunk跳过")
                continue
            body = payload if isinstance(payload, dict) else {'data': payload}

            # Incremental update of the ``answer`` field.
            if isinstance(body, dict) and 'answer' in body:
                current_answer = body['answer']
                # Guard: the answer must be a non-None string.
                if current_answer is None:
                    logger.warning(f"[RAG_SERVER {chunk_time:.3f}] ⚠️ current_answer为None跳过处理")
                    continue
                if not isinstance(current_answer, str):
                    logger.warning(f"[RAG_SERVER {chunk_time:.3f}] ⚠️ current_answer不是字符串类型: {type(current_answer)}")
                    current_answer = str(current_answer)

                if current_answer.startswith(last_answer):
                    # Forward only the newly appended text; nothing is sent
                    # when the answer did not grow.
                    delta = current_answer[len(last_answer):]
                    if delta:
                        body['answer'] = delta
                        last_answer = current_answer
                        logger.info(f"[RAG_SERVER {chunk_time:.3f}] 📝 Chunk #{chunk_count} 处理完成delta长度: {len(delta)}")
                        yield format_sse(body)
                else:
                    # The answer no longer extends the previous one:
                    # forward it unchanged and restart tracking from it.
                    last_answer = current_answer
                    logger.info(f"[RAG_SERVER {chunk_time:.3f}] 🔄 Chunk #{chunk_count} 上下文重置")
                    yield format_sse(body)
            else:
                logger.info(f"[RAG_SERVER {chunk_time:.3f}] 📦 Chunk #{chunk_count} 其他数据: {body}")
                yield format_sse(body)

        # End of stream.
        stream_end_time = time.time()
        end_message = format_sse({'status': 'completed'}, event='end')
        yield end_message
        logger.info(f"[RAG_SERVER {stream_end_time:.3f}] 🏁 流式响应完成chat_id: {chat_id}")
        logger.info(f"[RAG_SERVER {stream_end_time:.3f}] 📊 总共处理chunk数量: {chunk_count}")
    except Exception as exc:
        error_time = time.time()
        logger.exception(f'[RAG_SERVER {error_time:.3f}] ragflow流式对话异常: {exc}')
        error_message = format_sse({'message': str(exc)}, event='error')
        yield error_message
    finally:
        total_time = time.perf_counter() - start_time
        logger.info(f'[RAG_SERVER {time.time():.3f}] ⏱️ Total Stream Duration ({chat_id}): {total_time:.3f}s')
# Chat assistant list
@ragflowController.post('/get_chat_assistant_list')
async def get_chat_assistant_list(query_params: RagflowListQueryModel):
    """
    Return the chat assistant list.

    Forwards the query model to
    ``RAGFlowService.get_chat_assistant_list_services`` and wraps the raw
    result with ``parse_result``.
    """
    return parse_result(RAGFlowService.get_chat_assistant_list_services(query_params))
# Create session
@ragflowController.post('/create_session_with_chat')
async def create_session_with_chat(create_params: CreateSessionWithChatModel):
    """
    Create a chat session.

    Forwards the request model to
    ``RAGFlowService.create_session_with_chat_services`` and wraps the raw
    result with ``parse_result``.
    """
    return parse_result(RAGFlowService.create_session_with_chat_services(create_params))
# Update chat assistant
@ragflowController.post('/update_chat_assistant')
async def update_chat_assistant(update_params: UpdateChatAssistantModel):
    """
    Update a chat assistant.

    Forwards the request model to
    ``RAGFlowService.update_chat_assistant_services`` and wraps the raw
    result with ``parse_result``.
    """
    return parse_result(RAGFlowService.update_chat_assistant_services(update_params))
# Dataset endpoints
@ragflowController.post('/dataset_list')
async def get_system_ragflow_list(request: Request, ragflow_list_query: Any):
    """
    Return the dataset list.

    Calls ``RAGFlowService.get_ragflow_dataset_list_services`` with ``None``
    as its first argument and the query object as its second, then wraps the
    raw result with ``parse_result``. ``request`` is accepted but not used.
    """
    # NOTE(review): the parameter is typed ``Any`` to avoid importing the
    # concrete model; FastAPI may then not read it from the JSON body —
    # confirm how clients send it.
    datasets = RAGFlowService.get_ragflow_dataset_list_services(None, ragflow_list_query)
    return parse_result(datasets)
@ragflowController.post('/create_dataset')
async def create_dataset(create_dataset_params: Any):
    """
    Create a dataset.

    Forwards the parameters to ``RAGFlowService.create_dataset_services``
    and wraps the raw result with ``parse_result``.
    """
    # NOTE(review): typed ``Any`` to avoid importing the concrete model;
    # confirm FastAPI binds it from the request as intended.
    return parse_result(RAGFlowService.create_dataset_services(create_dataset_params))
@ragflowController.post('/delete_datasets')
async def delete_datasets(delete_params: Any):
    """
    Delete datasets.

    Forwards the parameters to ``RAGFlowService.delete_datasets_services``
    and wraps the raw result with ``parse_result``.
    """
    # NOTE(review): typed ``Any`` to avoid importing the concrete model;
    # confirm FastAPI binds it from the request as intended.
    return parse_result(RAGFlowService.delete_datasets_services(delete_params))