kangda-robot-backend/ruoyi-fastapi-backend/module_admin/controller/ragflow_controller.py
2025-12-24 19:42:18 +08:00

463 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
简化的RAGFlow控制器 - 集成语义缓存
"""
import asyncio
import hashlib
import json
import time
from typing import Any, Dict, Generator, Optional, Tuple
from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
from module_admin.service.login_service import LoginService
from module_admin.service.ragflow_service import RAGFlowService
from module_admin.service.search_service import SearchService
from module_admin.entity.vo.ragflow_vo import (
ConverseWithChatAssistantModel,
CreateSessionWithChatModel,
UpdateChatAssistantModel,
RagflowListQueryModel
)
from utils.log_util import logger
from utils.response_util import ResponseUtil
from utils.semantic_cache_service import get_semantic_cache_service, lookup_question, store_qa_pair
def _get_question_hash(question: str) -> str:
"""计算问题的hash值"""
import hashlib
normalized = question.lower().strip()
return hashlib.md5(normalized.encode('utf-8')).hexdigest()[:16]
from utils.static_qa_service import get_static_qa_service
async def _async_store_qa(chat_id: str, question: str, answer: str, redis) -> None:
    """Persist one Q/A pair into the semantic cache, logging (not raising) on failure."""
    question_hash = _get_question_hash(question)
    logger.info(f"[SemanticCache] 存储QA | chat_id={chat_id} | question={question} | hash={question_hash} | answer_length={len(answer)}")
    try:
        await store_qa_pair(chat_id, question, answer, redis)
    except Exception as e:
        # Cache storage is best-effort: never let it break the request path.
        logger.warning(f"[SemanticCache] 异步存储失败: {e}")
# Standard APIRouter with an authentication dependency: every endpoint
# registered below requires a logged-in user (LoginService.get_current_user)
# before the handler body runs.
ragflowController = APIRouter(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)])
def format_sse(data: dict, event: str | None = None) -> str:
"""格式化SSE数据"""
payload = json.dumps(data, ensure_ascii=False)
prefix = f'event: {event}\n' if event else ''
return f'{prefix}data: {payload}\n\n'
def parse_result(result: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize a service-layer result dict into a standard API response.

    A `code` of 0 (or missing) means success; anything else is mapped to an
    error response carrying the service's message and any partial data.
    """
    if result.get('code', 0) == 0:
        return ResponseUtil.success(data=result.get('data', None))
    # Services are inconsistent about 'message' vs 'msg'; accept either.
    error_msg = result.get('message') or result.get('msg') or '接口异常'
    return ResponseUtil.error(msg=error_msg, data=result.get('data', None))
def build_chat_cache_key(chat_id: str, question: str) -> str:
    """Derive a deterministic Redis key for a (chat_id, question) pair.

    The question is hashed (SHA-256) so arbitrary text is safe to embed in
    the key and identical questions always map to the same key.
    """
    question_digest = hashlib.sha256(question.encode('utf-8')).hexdigest()
    return ':'.join(('ragflow', 'chat', chat_id, question_digest))
def remove_style_hint(question: str) -> Tuple[str, bool]:
    """Strip a trailing language-style directive from a question.

    Args:
        question: the raw question text (may be empty/None-ish).

    Returns:
        A ``(cleaned_question, removed)`` tuple; ``removed`` is True only
        when a style directive was found and stripped.
    """
    if not question:
        return question, False
    import re
    pattern = r'语言风格\s*[:]\s*[^,。!?\n]*'
    match = re.search(pattern, question)
    if match is None:
        return question, False
    style_hint = match.group(0)
    cleaned_question = re.sub(pattern, '', question).strip()
    # Collapse any whitespace runs left behind by the removal.
    cleaned_question = re.sub(r'\s+', ' ', cleaned_question)
    logger.info(f'[StyleHint] 移除风格提示词: {style_hint} | 原始问题: {question} | 处理后: {cleaned_question}')
    return cleaned_question, True
# Chat endpoint. NOTE(review): the original label said "non-streaming", but
# every success branch below returns an SSE StreamingResponse.
@ragflowController.post('/converse_with_chat_assistant')
async def converse_with_chat_assistant(
    request: Request,
    converse_params: ConverseWithChatAssistantModel,
):
    """
    Converse with a chat assistant — semantic-cache integrated version.

    Matching pipeline:
    1. Strip the language-style hint from the question.
    2. Static FAQ match (threshold=0.70) — on hit, stream the cached answer.
    3. RAG history cache match (threshold=0.60) — on hit, stream the cached answer.
    4. Intent routing — a SEARCH intent is delegated to SearchService.
    5. Otherwise call the native RAG service with the ORIGINAL question
       (style hint restored); the final answer is stored back into the
       semantic cache keyed by the cleaned question.
    """
    start_time = time.perf_counter()
    original_question = converse_params.question
    cleaned_question, style_removed = remove_style_hint(original_question)
    if style_removed:
        # Cache lookups below operate on the cleaned question; the original
        # is restored before the native RAG call further down.
        converse_params.question = cleaned_question
    # Shared redis handle from app state; None disables the semantic-cache path.
    redis = getattr(request.app.state, 'redis', None)
    cache_key = None
    cached_answer = None
    cache_similarity = 0.0
    cache_source = None
    # ========== 1. Static FAQ matching ==========
    try:
        static_qa_service = get_static_qa_service()
        logger.info(f'[StaticQA] 开始匹配 | chat_id={converse_params.chat_id} | question={converse_params.question} | threshold=0.70')
        static_match, static_sim = static_qa_service.find_match(converse_params.question, threshold=0.70)
        logger.info(f'[StaticQA] 匹配结果 | similarity={static_sim:.2f} | matched={static_match is not None}')
        if static_match:
            cached_answer = static_match.get('answer', '')
            cache_similarity = static_sim
            cache_source = 'static_qa'
            logger.info(f'[RAG_SOURCE] 命中静态FAQ | chat_id={converse_params.chat_id} | question={converse_params.question} | similarity={cache_similarity:.2f} | answer_length={len(cached_answer)}')
            logger.info(f'[StaticQA] 流式响应使用静态问答答案chat_id={converse_params.chat_id}')
            # Stream the FAQ answer directly; headers disable proxy buffering.
            return StreamingResponse(
                stream_cached_response(cached_answer, converse_params.chat_id, start_time, cache_source='static_qa'),
                media_type='text/event-stream',
                headers={
                    'Cache-Control': 'no-cache',
                    'Connection': 'keep-alive',
                    'X-Accel-Buffering': 'no',
                    'Transfer-Encoding': 'chunked'
                }
            )
    except Exception as e:
        # Static FAQ is best-effort; fall through to the next stage on error.
        logger.warning(f'[StaticQA] 静态问答匹配失败: {e}')
    # ========== 2. RAG history cache lookup ==========
    logger.info(f'[SemanticCache] 准备执行RAG历史缓存查找 | redis={redis is not None} | chat_id={converse_params.chat_id}')
    if redis:
        lookup_hash = _get_question_hash(converse_params.question)
        logger.info(f'[SemanticCache] 开始查找 | chat_id={converse_params.chat_id} | question={converse_params.question} | hash={lookup_hash} | threshold=0.60')
        cache_result = await lookup_question(
            converse_params.chat_id,
            converse_params.question,
            redis
        )
        logger.info(f'[SemanticCache] 查找结果 | found={cache_result is not None}')
        if cache_result:
            cached_answer, cache_similarity = cache_result
            cache_source = 'rag_history'
            logger.info(f'[RAG_SOURCE] 命中RAG会话历史 | chat_id={converse_params.chat_id} | question={converse_params.question} | similarity={cache_similarity:.2f} | answer_length={len(cached_answer)}')
            logger.info(f'[SemanticCache] 流式响应使用RAG历史缓存答案chat_id={converse_params.chat_id}')
            return StreamingResponse(
                stream_cached_response(cached_answer, converse_params.chat_id, start_time, cache_source='rag_history'),
                media_type='text/event-stream',
                headers={
                    'Cache-Control': 'no-cache',
                    'Connection': 'keep-alive',
                    'X-Accel-Buffering': 'no',
                    'Transfer-Encoding': 'chunked'
                }
            )
    # Check whether the search service should handle this question instead.
    try:
        intent = await SearchService.classify_intent(converse_params.question)
        logger.info(f"[RAG_PROFILE] Intent Router ({converse_params.chat_id}): {intent}")
        if intent == 'SEARCH':
            return await SearchService.handle_search_chat(converse_params, redis)
    except Exception as e:
        # Intent classification failure falls back to the RAG service.
        logger.warning(f"意图分类失败使用RAG服务: {e}")
    # Fall through to the (synchronous) native RAGFlow service.
    try:
        # Restore the original question (including the language-style hint).
        if style_removed:
            converse_params.question = original_question
            logger.info(f'[StyleHint] 恢复原始问题: {original_question}')
        logger.info(f'[RAG_SOURCE] 调用原生RAG服务 | chat_id={converse_params.chat_id} | question={converse_params.question}')
        result = RAGFlowService.converse_with_chat_assistant_services(converse_params)
        # Cache under the cleaned question so future cleaned lookups hit.
        cache_question = cleaned_question if style_removed else converse_params.question
        store_hash = _get_question_hash(cache_question)
        logger.info(f'[RAG_CACHE] 准备存储 | chat_id={converse_params.chat_id} | question={cache_question} | hash={store_hash}')
        # Factory capturing chat_id/question so the stream generator can store
        # the final answer after the stream completes.
        async def make_cache_store(chat_id: str, question: str):
            async def store_answer(answer: str):
                # Skip trivially short answers (< 10 chars after strip).
                if redis and answer and len(answer.strip()) >= 10:
                    try:
                        await _async_store_qa(chat_id, question, answer, redis)
                        logger.info(f'[RAG_CACHE] 语义缓存存储成功 | chat_id={chat_id} | answer_length={len(answer)}')
                    except Exception as e:
                        logger.warning(f'语义缓存存储失败: {e}')
            return store_answer
        cache_store = await make_cache_store(converse_params.chat_id, cache_question)
        return StreamingResponse(
            stream_ragflow_response(result, converse_params.chat_id, start_time, cache_store_func=cache_store),
            media_type='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'Connection': 'keep-alive',
                'X-Accel-Buffering': 'no',
                'Transfer-Encoding': 'chunked'
            }
        )
    except Exception as e:
        logger.exception(f'[RAG_PROFILE] ragflow对话异常: {e}')
        return ResponseUtil.error(msg=f'对话服务异常: {str(e)}')
def stream_ragflow_response(result: Generator, chat_id: str, start_time: float, cache_store_func=None) -> Generator[str, None, None]:
    """
    Stream a RAGFlow response as SSE — synchronous version (reduces first-token latency).

    Args:
        result: generator of chunks from the RAGFlow service; each chunk is
            either a dict (possibly with 'code'/'data'/'answer') or raw data.
        chat_id: conversation id, used for logging only.
        start_time: request start (time.perf_counter) for total-duration logging.
        cache_store_func: optional async callable invoked with the final
            accumulated answer once the stream finishes.

    Yields:
        SSE-formatted strings: an initial ping, answer deltas, then an
        'end' event (or 'error' events on failure).
    """
    import time
    server_stream_start = time.time()
    logger.info(f"[RAG_SERVER {server_stream_start:.3f}] 🚀 开始流式响应处理chat_id: {chat_id}")
    last_answer = ""
    first_token_received = False
    start_stream_time = time.perf_counter()
    # Send a connection-established ping immediately to cut first-token latency.
    connection_time = time.time()
    ping_message = "event: ping\ndata: {\"status\": \"connected\"}\n\n"
    yield ping_message
    logger.info(f"[RAG_SERVER {connection_time:.3f}] 📡 连接建立ping发送耗时: {connection_time - server_stream_start:.3f}s")
    chunk_count = 0
    try:
        for chunk in result:
            chunk_time = time.time()
            chunk_count += 1
            # Log latency of the very first chunk.
            if not first_token_received:
                first_token_received = True
                latency = time.perf_counter() - start_stream_time
                logger.info(f"[RAG_SERVER {chunk_time:.3f}] 🎯 首token到达耗时: {latency:.3f}s")
            # Unwrap dict chunks to their 'data' payload.
            payload = chunk.get('data') if isinstance(chunk, dict) else chunk
            if not payload:
                logger.info(f"[RAG_SERVER {chunk_time:.3f}] ⚠️ 空chunk跳过")
                continue
            body = payload if isinstance(payload, dict) else {'data': payload}
            # Forward upstream errors as SSE 'error' events and keep streaming.
            if isinstance(chunk, dict) and chunk.get('code') and chunk.get('code') != 0:
                logger.error(f"[RAG_SERVER {chunk_time:.3f}] ❌ RAGFlow Stream Error: {chunk}")
                error_message = format_sse({'message': chunk.get('message', '流式处理异常')}, event='error')
                yield error_message
                continue
            # Incremental handling of the cumulative 'answer' field.
            if isinstance(body, dict) and 'answer' in body:
                current_answer = body['answer']
                # Guard: ensure current_answer is a non-None string.
                if current_answer is None:
                    logger.warning(f"[RAG_SERVER {chunk_time:.3f}] ⚠️ current_answer为None跳过处理")
                    continue
                if not isinstance(current_answer, str):
                    logger.warning(f"[RAG_SERVER {chunk_time:.3f}] ⚠️ current_answer不是字符串类型: {type(current_answer)}")
                    current_answer = str(current_answer)
                # Emit only the delta when the new answer extends the last one.
                if current_answer.startswith(last_answer):
                    delta = current_answer[len(last_answer):]
                    if delta:
                        body['answer'] = delta
                        last_answer = current_answer
                        logger.info(f"[RAG_SERVER {chunk_time:.3f}] 📝 Chunk #{chunk_count} 处理完成delta长度: {len(delta)}")
                        yield format_sse(body)
                else:
                    # Fallback: the answer no longer extends the previous one
                    # (context reset) — forward it whole and resync.
                    last_answer = current_answer
                    logger.info(f"[RAG_SERVER {chunk_time:.3f}] 🔄 Chunk #{chunk_count} 上下文重置")
                    yield format_sse(body)
            else:
                # Non-answer payloads (references, metadata, …) pass through as-is.
                logger.info(f"[RAG_SERVER {chunk_time:.3f}] 📦 Chunk #{chunk_count} 其他数据: {body}")
                yield format_sse(body)
        # Stream finished normally: emit the terminating 'end' event.
        stream_end_time = time.time()
        end_message = format_sse({'status': 'completed'}, event='end')
        yield end_message
        logger.info(f'[RAG_SERVER {stream_end_time:.3f}] 流式响应完成chat_id: {chat_id}')
        logger.info(f'[RAG_SERVER {stream_end_time:.3f}] 总共处理chunk数量: {chunk_count}')
        logger.info(f'[RAG_SOURCE] 原生RAG流式响应完成 | chat_id={chat_id} | total_chunks={chunk_count} | answer_length={len(last_answer)}')
    except Exception as exc:
        error_time = time.time()
        logger.exception(f'[RAG_SERVER {error_time:.3f}] ragflow流式对话异常: {exc}')
        error_message = format_sse({'message': str(exc)}, event='error')
        yield error_message
    finally:
        total_time = time.perf_counter() - start_time
        logger.info(f'[RAG_SERVER {time.time():.3f}] ⏱️ Total Stream Duration ({chat_id}): {total_time:.3f}s')
        # After the stream ends, persist the accumulated answer to the cache.
        # NOTE(review): this generator runs synchronously (threadpool), so a
        # private event loop is spun up to drive the async store function.
        if cache_store_func and last_answer and len(last_answer.strip()) >= 10:
            try:
                import asyncio
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                try:
                    loop.run_until_complete(cache_store_func(last_answer))
                    logger.info(f'[RAG_CACHE] 缓存存储完成 | chat_id={chat_id} | answer_length={len(last_answer)}')
                finally:
                    loop.close()
            except Exception as cache_err:
                logger.warning(f'[RAG_CACHE] 缓存存储失败: {cache_err}')
# Chat-assistant listing
@ragflowController.post('/get_chat_assistant_list')
async def get_chat_assistant_list(
    query_params: RagflowListQueryModel,  # validated Pydantic query model
):
    """Return the list of chat assistants from the RAGFlow service layer."""
    return parse_result(RAGFlowService.get_chat_assistant_list_services(query_params))
# Session creation
@ragflowController.post('/create_session_with_chat')
async def create_session_with_chat(
    create_params: CreateSessionWithChatModel,  # validated Pydantic model
):
    """Create a chat session via the RAGFlow service layer."""
    return parse_result(RAGFlowService.create_session_with_chat_services(create_params))
# Chat-assistant update
@ragflowController.post('/update_chat_assistant')
async def update_chat_assistant(
    update_params: UpdateChatAssistantModel,  # validated Pydantic model
):
    """Update a chat assistant via the RAGFlow service layer."""
    return parse_result(RAGFlowService.update_chat_assistant_services(update_params))
# Dataset-related endpoints
@ragflowController.post('/dataset_list')
async def get_system_ragflow_list(
    request: Request,
    ragflow_list_query: Any,  # Any keeps the concrete model class un-imported
):
    """Return the dataset list from the RAGFlow service layer."""
    service_result = RAGFlowService.get_ragflow_dataset_list_services(None, ragflow_list_query)
    return parse_result(service_result)
@ragflowController.post('/create_dataset')
async def create_dataset(
    create_dataset_params: Any,  # Any keeps the concrete model class un-imported
):
    """Create a dataset via the RAGFlow service layer."""
    return parse_result(RAGFlowService.create_dataset_services(create_dataset_params))
@ragflowController.post('/delete_datasets')
async def delete_datasets(
    delete_params: Any,  # Any keeps the concrete model class un-imported
):
    """Delete datasets via the RAGFlow service layer."""
    return parse_result(RAGFlowService.delete_datasets_services(delete_params))
def stream_cached_response(cached_answer: str, chat_id: str, start_time: float, cache_source: str = 'cache') -> Generator[str, None, None]:
    """
    Stream a cached answer back to the client as SSE.

    Args:
        cached_answer: the cached answer text (may be empty/falsy — then only
            the ping and 'end' events are emitted).
        chat_id: conversation id, used only for logging.
        start_time: request start (time.perf_counter) for latency logging.
        cache_source: cache origin tag ('static_qa' or 'rag_history').

    Yields:
        SSE-formatted strings: a ping, answer chunks, then an 'end' event
        carrying from_cache/source/total_time metadata.
    """
    # Uses the module-level `time` import; the shadowing function-local
    # `import time` and the unused `first_token_time` were removed.
    server_stream_start = time.time()
    logger.info(f"[CACHE_STREAM {server_stream_start:.3f}] 🚀 开始流式返回缓存答案chat_id: {chat_id}")
    # Send a connection-established ping immediately to cut first-byte latency.
    connection_time = time.time()
    ping_message = "event: ping\ndata: {\"status\": \"connected\"}\n\n"
    yield ping_message
    logger.info(f"[CACHE_STREAM {connection_time:.3f}] 📡 连接建立ping发送耗时: {connection_time - server_stream_start:.3f}s")
    # Emit the answer in small chunks to simulate a typing effect on the client.
    if cached_answer:
        chunk_size = 10  # characters per SSE chunk
        answer_chunks = [cached_answer[i:i + chunk_size] for i in range(0, len(cached_answer), chunk_size)]
        for i, chunk in enumerate(answer_chunks):
            chunk_time = time.time()
            # First-token latency, measured from the request start.
            if i == 0:
                latency = time.perf_counter() - start_time
                logger.info(f"[CACHE_STREAM {chunk_time:.3f}] 🎯 首token到达缓存耗时: {latency:.3f}s")
            body = {'answer': chunk, 'chunk_index': i, 'total_chunks': len(answer_chunks)}
            yield format_sse(body)
            # Optional pacing delay for a typing effect (disabled):
            # time.sleep(0.01)
        logger.info(f"[CACHE_STREAM {chunk_time:.3f}] 📝 缓存答案流式发送完成,共 {len(answer_chunks)} 个chunk")
    # Terminating 'end' event with cache metadata.
    stream_end_time = time.time()
    end_message = format_sse({
        'status': 'completed',
        'from_cache': True,
        'source': cache_source,
        'total_time': stream_end_time - server_stream_start
    }, event='end')
    yield end_message
    logger.info(f"[CACHE_STREAM {stream_end_time:.3f}] 🏁 缓存流式响应完成")
    if cached_answer:
        logger.info(f"[CACHE_STREAM {stream_end_time:.3f}] 📊 缓存答案长度: {len(cached_answer)} 字符")
    total_time = time.perf_counter() - start_time
    logger.info(f'[CACHE_STREAM {time.time():.3f}] ⏱️ Total Cache Stream Duration ({chat_id}): {total_time:.3f}s')