"""
|
||
简化的RAGFlow控制器 - 集成语义缓存
|
||
"""
|
||
|
||
import asyncio
|
||
import hashlib
|
||
import json
|
||
import time
|
||
from typing import Any, Dict, Generator, Optional, Tuple
|
||
|
||
from fastapi import APIRouter, Depends, Request
|
||
from fastapi.responses import StreamingResponse
|
||
|
||
from module_admin.service.login_service import LoginService
|
||
from module_admin.service.ragflow_service import RAGFlowService
|
||
from module_admin.service.search_service import SearchService
|
||
from module_admin.entity.vo.ragflow_vo import (
|
||
ConverseWithChatAssistantModel,
|
||
CreateSessionWithChatModel,
|
||
UpdateChatAssistantModel,
|
||
RagflowListQueryModel
|
||
)
|
||
from utils.log_util import logger
|
||
from utils.response_util import ResponseUtil
|
||
from utils.semantic_cache_service import get_semantic_cache_service, lookup_question, store_qa_pair
|
||
from utils.static_qa_service import get_static_qa_service
|
||
|
||
|
||
async def _async_store_qa(chat_id: str, question: str, answer: str, redis) -> None:
    """Persist one Q/A pair into the semantic cache, best-effort.

    Any failure is downgraded to a warning so a cache outage can never
    break the calling request path.
    """
    try:
        await store_qa_pair(chat_id, question, answer, redis)
    except Exception as exc:
        logger.warning(f"[SemanticCache] 异步存储失败: {exc}")
# Standard APIRouter; every route under /system/ragflow requires an authenticated user.
ragflowController = APIRouter(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)])
def format_sse(data: dict, event: str | None = None) -> str:
|
||
"""格式化SSE数据"""
|
||
payload = json.dumps(data, ensure_ascii=False)
|
||
prefix = f'event: {event}\n' if event else ''
|
||
return f'{prefix}data: {payload}\n\n'
|
||
|
||
|
||
def parse_result(result: Dict[str, Any]) -> Dict[str, Any]:
    """Map a raw RAGFlow service result onto the unified API response.

    A non-zero ``code`` becomes an error response carrying the service's
    message (with a generic fallback); zero becomes a success response.
    """
    data = result.get('data', None)
    if result.get('code', 0) != 0:
        msg = result.get('message') or result.get('msg') or '接口异常'
        return ResponseUtil.error(msg=msg, data=data)
    return ResponseUtil.success(data=data)
def build_chat_cache_key(chat_id: str, question: str) -> str:
    """Derive the exact-match Redis key for a (chat, question) pair.

    The question is SHA-256 hashed so arbitrary-length text yields a
    fixed-size, key-safe suffix.
    """
    question_hash = hashlib.sha256(question.encode('utf-8')).hexdigest()
    return ':'.join(('ragflow:chat', chat_id, question_hash))
def remove_style_hint(question: str) -> Tuple[str, bool]:
    """Strip a "语言风格: ..." (language-style) hint from *question*.

    Args:
        question: The raw user question (may be empty/None).

    Returns:
        A tuple of (cleaned question, True) when a hint was found and
        removed, otherwise (original question, False).
    """
    if not question:
        return question, False

    import re

    pattern = r'语言风格\s*[::]\s*[^,。!?\n]*'
    found = re.search(pattern, question)
    if found is None:
        return question, False

    # Remove every occurrence of the hint, then collapse whitespace.
    cleaned = re.sub(pattern, '', question).strip()
    cleaned = re.sub(r'\s+', ' ', cleaned)
    logger.info(f'[StyleHint] 移除风格提示词: {found.group(0)} | 原始问题: {question} | 处理后: {cleaned}')
    return cleaned, True
# Conversation endpoint (supports both streaming and non-streaming modes).
@ragflowController.post('/converse_with_chat_assistant')
async def converse_with_chat_assistant(
    request: Request,
    converse_params: ConverseWithChatAssistantModel,
):
    """
    Converse with a chat assistant, with layered caching.

    Matching pipeline:
    1. Strip the language-style hint from the question.
    2. Static FAQ match (threshold=0.70).
    3. Semantic RAG-history cache match (threshold=0.60).
    4. Intent routing to the search service.
    5. Exact-match Redis cache (non-stream only).
    6. Native RAG service call (original question restored).
    """
    start_time = time.perf_counter()

    original_question = converse_params.question
    cleaned_question, style_removed = remove_style_hint(original_question)

    # Cache layers match against the cleaned question so style hints
    # don't fragment cache keys.
    if style_removed:
        converse_params.question = cleaned_question

    # Redis is optional; every cache layer degrades gracefully without it.
    redis = getattr(request.app.state, 'redis', None)
    cache_key = None
    cached_answer = None
    cache_similarity = 0.0
    cache_source = None

    # ========== 1. Static FAQ matching ==========
    try:
        static_qa_service = get_static_qa_service()
        logger.info(f'[StaticQA] 开始匹配 | chat_id={converse_params.chat_id} | question={converse_params.question} | threshold=0.70')
        static_match, static_sim = static_qa_service.find_match(converse_params.question, threshold=0.70)
        logger.info(f'[StaticQA] 匹配结果 | similarity={static_sim:.2f} | matched={static_match is not None}')

        if static_match:
            cached_answer = static_match.get('answer', '')
            cache_similarity = static_sim
            cache_source = 'static_qa'
            logger.info(f'[RAG_SOURCE] 命中静态FAQ | chat_id={converse_params.chat_id} | question={converse_params.question} | similarity={cache_similarity:.2f} | answer_length={len(cached_answer)}')

            # Non-stream: return the FAQ answer directly.
            if not converse_params.stream:
                return ResponseUtil.success(data={'answer': cached_answer, 'from_cache': True, 'similarity': cache_similarity, 'source': 'static_qa'})

            # Stream: replay the FAQ answer as SSE chunks.
            if converse_params.stream:
                logger.info(f'[StaticQA] 流式响应使用静态问答答案,chat_id={converse_params.chat_id}')
                return StreamingResponse(
                    stream_cached_response(cached_answer, converse_params.chat_id, start_time, cache_source='static_qa'),
                    media_type='text/event-stream',
                    headers={
                        'Cache-Control': 'no-cache',
                        'Connection': 'keep-alive',
                        'X-Accel-Buffering': 'no',
                        'Transfer-Encoding': 'chunked'
                    }
                )
    except Exception as e:
        logger.warning(f'[StaticQA] 静态问答匹配失败: {e}')

    # ========== 2. Semantic RAG-history cache lookup ==========
    if redis:
        logger.info(f'[SemanticCache] 开始查找 | chat_id={converse_params.chat_id} | question={converse_params.question} | threshold=0.60')
        cache_result = await lookup_question(
            converse_params.chat_id,
            converse_params.question,
            redis
        )
        logger.info(f'[SemanticCache] 查找结果 | found={cache_result is not None}')
        if cache_result:
            cached_answer, cache_similarity = cache_result
            cache_source = 'rag_history'
            logger.info(f'[RAG_SOURCE] 命中RAG会话历史 | chat_id={converse_params.chat_id} | question={converse_params.question} | similarity={cache_similarity:.2f} | answer_length={len(cached_answer)}')

            # Non-stream: return the cached answer directly.
            if not converse_params.stream:
                return ResponseUtil.success(data={'answer': cached_answer, 'from_cache': True, 'similarity': cache_similarity, 'source': 'rag_history'})

            # Stream: replay the cached answer as SSE chunks.
            if converse_params.stream:
                logger.info(f'[SemanticCache] 流式响应使用RAG历史缓存答案,chat_id={converse_params.chat_id}')
                return StreamingResponse(
                    stream_cached_response(cached_answer, converse_params.chat_id, start_time, cache_source='rag_history'),
                    media_type='text/event-stream',
                    headers={
                        'Cache-Control': 'no-cache',
                        'Connection': 'keep-alive',
                        'X-Accel-Buffering': 'no',
                        'Transfer-Encoding': 'chunked'
                    }
                )

    # ========== 3. Intent routing: divert SEARCH intents to the search service ==========
    try:
        intent = await SearchService.classify_intent(converse_params.question)
        logger.info(f"[RAG_PROFILE] Intent Router ({converse_params.chat_id}): {intent}")

        if intent == 'SEARCH':
            return await SearchService.handle_search_chat(converse_params, redis)
    except Exception as e:
        logger.warning(f"意图分类失败,使用RAG服务: {e}")

    # ========== 4. Exact-match cache (non-stream only) ==========
    if not converse_params.stream and redis:
        cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
        try:
            cached = await redis.get(cache_key)
            if cached:
                logger.info('ragflow对话命中缓存: chat=%s', converse_params.chat_id)
                return ResponseUtil.success(json.loads(cached))
        except Exception as e:
            logger.warning(f"缓存获取失败: {e}")

    # ========== 5. Native RAGFlow service call ==========
    try:
        # Restore the original question (including the style hint) so the
        # RAG model still sees the user's style instruction.
        if style_removed:
            converse_params.question = original_question
            logger.info(f'[StyleHint] 恢复原始问题: {original_question}')

        logger.info(f'[RAG_SOURCE] 调用原生RAG服务 | chat_id={converse_params.chat_id} | question={converse_params.question}')
        result = RAGFlowService.converse_with_chat_assistant_services(converse_params)

        # Streaming response: attach a callback that stores the final
        # answer into the semantic cache once the stream completes.
        if converse_params.stream:
            # Store under the cleaned question so future lookups (which
            # also clean the question) can hit this entry.
            cache_question = cleaned_question if style_removed else converse_params.question

            def make_cache_store(chat_id: str, question: str):
                # Plain factory (was needlessly async): the closure just
                # captures stable copies of the ids.
                async def store_answer(answer: str):
                    # Skip trivial answers; < 10 chars are unlikely to be worth caching.
                    if redis and answer and len(answer.strip()) >= 10:
                        try:
                            await _async_store_qa(chat_id, question, answer, redis)
                            logger.info(f'[RAG_CACHE] 语义缓存存储成功 | chat_id={chat_id} | answer_length={len(answer)}')
                        except Exception as e:
                            logger.warning(f'语义缓存存储失败: {e}')
                return store_answer

            cache_store = make_cache_store(converse_params.chat_id, cache_question)

            return StreamingResponse(
                stream_ragflow_response(result, converse_params.chat_id, start_time, cache_store_func=cache_store),
                media_type='text/event-stream',
                headers={
                    'Cache-Control': 'no-cache',
                    'Connection': 'keep-alive',
                    'X-Accel-Buffering': 'no',
                    'Transfer-Encoding': 'chunked'
                }
            )

        # FIX: the non-stream path previously fell through without a return,
        # so any cache miss answered `null`. Populate the exact-match key
        # (computed above but never written before) and return the parsed
        # result.
        if cache_key and redis and isinstance(result, dict) and result.get('code', 0) == 0:
            try:
                # Store only the `data` payload — the cache-read branch above
                # feeds json.loads(cached) straight into ResponseUtil.success.
                await redis.set(cache_key, json.dumps(result.get('data', None), ensure_ascii=False))
            except Exception as cache_err:
                logger.warning(f"缓存写入失败: {cache_err}")
        return parse_result(result)

    except Exception as e:
        logger.exception(f'[RAG_PROFILE] ragflow对话异常: {e}')
        return ResponseUtil.error(msg=f'对话服务异常: {str(e)}')
def stream_ragflow_response(result: Generator, chat_id: str, start_time: float, cache_store_func=None) -> Generator[str, None, None]:
    """
    Stream a RAGFlow response to the client as SSE messages (sync generator).

    Sends an immediate ``ping`` event to cut first-token latency, converts
    RAGFlow's cumulative ``answer`` snapshots into incremental deltas, and
    after the stream ends runs ``cache_store_func`` (if given) with the
    final full answer so it can be stored in the semantic cache.

    Args:
        result: Iterable of chunks from the RAGFlow service; dict chunks may
            carry ``data``/``code``/``message``/``answer`` fields.
        chat_id: Session id, used only for logging.
        start_time: ``time.perf_counter()`` taken at request start, for the
            total-duration log in ``finally``.
        cache_store_func: Optional async callable ``(answer: str) -> None``
            invoked once with the complete answer after the stream finishes.
    """
    import time

    server_stream_start = time.time()
    logger.info(f"[RAG_SERVER {server_stream_start:.3f}] 🚀 开始流式响应处理,chat_id: {chat_id}")

    # last_answer accumulates the full answer seen so far; chunks arrive as
    # cumulative snapshots, so deltas are computed against it.
    last_answer = ""
    first_token_received = False
    start_stream_time = time.perf_counter()

    # Send a connection-established ping immediately so the client is not
    # left waiting for the upstream's first token.
    connection_time = time.time()
    ping_message = "event: ping\ndata: {\"status\": \"connected\"}\n\n"
    yield ping_message
    logger.info(f"[RAG_SERVER {connection_time:.3f}] 📡 连接建立ping发送,耗时: {connection_time - server_stream_start:.3f}s")

    chunk_count = 0
    try:
        for chunk in result:
            chunk_time = time.time()
            chunk_count += 1

            # Log first-token latency once.
            if not first_token_received:
                first_token_received = True
                latency = time.perf_counter() - start_stream_time
                logger.info(f"[RAG_SERVER {chunk_time:.3f}] 🎯 首token到达,耗时: {latency:.3f}s")

            # Unwrap the payload: dict chunks carry it under 'data'.
            payload = chunk.get('data') if isinstance(chunk, dict) else chunk
            if not payload:
                logger.info(f"[RAG_SERVER {chunk_time:.3f}] ⚠️ 空chunk跳过")
                continue

            body = payload if isinstance(payload, dict) else {'data': payload}

            # Upstream error chunk: forward as an SSE 'error' event and keep going.
            if isinstance(chunk, dict) and chunk.get('code') and chunk.get('code') != 0:
                logger.error(f"[RAG_SERVER {chunk_time:.3f}] ❌ RAGFlow Stream Error: {chunk}")
                error_message = format_sse({'message': chunk.get('message', '流式处理异常')}, event='error')
                yield error_message
                continue

            # Incremental handling of the cumulative 'answer' field.
            if isinstance(body, dict) and 'answer' in body:
                current_answer = body['answer']

                # Safety: skip None answers, coerce non-strings.
                if current_answer is None:
                    logger.warning(f"[RAG_SERVER {chunk_time:.3f}] ⚠️ current_answer为None,跳过处理")
                    continue

                if not isinstance(current_answer, str):
                    logger.warning(f"[RAG_SERVER {chunk_time:.3f}] ⚠️ current_answer不是字符串,类型: {type(current_answer)}")
                    current_answer = str(current_answer)

                # Normal case: the snapshot extends what we've seen — emit only the delta.
                if current_answer.startswith(last_answer):
                    delta = current_answer[len(last_answer):]
                    if delta:
                        body['answer'] = delta
                        last_answer = current_answer
                        logger.info(f"[RAG_SERVER {chunk_time:.3f}] 📝 Chunk #{chunk_count} 处理完成,delta长度: {len(delta)}")
                        yield format_sse(body)
                else:
                    # Fallback: the snapshot no longer extends the accumulator
                    # (upstream context reset) — forward it whole.
                    last_answer = current_answer
                    logger.info(f"[RAG_SERVER {chunk_time:.3f}] 🔄 Chunk #{chunk_count} 上下文重置")
                    yield format_sse(body)
            else:
                # Non-answer chunks (references, metadata, ...) are forwarded as-is.
                logger.info(f"[RAG_SERVER {chunk_time:.3f}] 📦 Chunk #{chunk_count} 其他数据: {body}")
                yield format_sse(body)

        # Stream finished normally: emit the terminating 'end' event.
        stream_end_time = time.time()
        end_message = format_sse({'status': 'completed'}, event='end')
        yield end_message
        logger.info(f'[RAG_SERVER {stream_end_time:.3f}] 流式响应完成,chat_id: {chat_id}')
        logger.info(f'[RAG_SERVER {stream_end_time:.3f}] 总共处理chunk数量: {chunk_count}')
        logger.info(f'[RAG_SOURCE] 原生RAG流式响应完成 | chat_id={chat_id} | total_chunks={chunk_count} | answer_length={len(last_answer)}')

    except Exception as exc:
        error_time = time.time()
        logger.exception(f'[RAG_SERVER {error_time:.3f}] ragflow流式对话异常: {exc}')
        error_message = format_sse({'message': str(exc)}, event='error')
        yield error_message
    finally:
        total_time = time.perf_counter() - start_time
        logger.info(f'[RAG_SERVER {time.time():.3f}] ⏱️ Total Stream Duration ({chat_id}): {total_time:.3f}s')

        # After the stream ends, persist the final answer via the callback.
        # A private event loop is spun up because this is a sync generator
        # (FastAPI runs it in a worker thread, which has no running loop).
        if cache_store_func and last_answer and len(last_answer.strip()) >= 10:
            try:
                import asyncio
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                try:
                    loop.run_until_complete(cache_store_func(last_answer))
                    logger.info(f'[RAG_CACHE] 缓存存储完成 | chat_id={chat_id} | answer_length={len(last_answer)}')
                finally:
                    loop.close()
            except Exception as cache_err:
                logger.warning(f'[RAG_CACHE] 缓存存储失败: {cache_err}')
# Chat-assistant listing endpoint.
@ragflowController.post('/get_chat_assistant_list')
async def get_chat_assistant_list(
    query_params: RagflowListQueryModel,
):
    """Return the list of chat assistants known to RAGFlow."""
    raw = RAGFlowService.get_chat_assistant_list_services(query_params)
    return parse_result(raw)
# Session-creation endpoint.
@ragflowController.post('/create_session_with_chat')
async def create_session_with_chat(
    create_params: CreateSessionWithChatModel,
):
    """Create a new chat session for an assistant."""
    raw = RAGFlowService.create_session_with_chat_services(create_params)
    return parse_result(raw)
# Chat-assistant update endpoint.
@ragflowController.post('/update_chat_assistant')
async def update_chat_assistant(
    update_params: UpdateChatAssistantModel,
):
    """Update an existing chat assistant's configuration."""
    raw = RAGFlowService.update_chat_assistant_services(update_params)
    return parse_result(raw)
# Dataset endpoints.
@ragflowController.post('/dataset_list')
async def get_system_ragflow_list(
    request: Request,
    ragflow_list_query: Any,
):
    """List RAGFlow datasets (query model left as Any to avoid a model import)."""
    raw = RAGFlowService.get_ragflow_dataset_list_services(None, ragflow_list_query)
    return parse_result(raw)
@ragflowController.post('/create_dataset')
async def create_dataset(
    create_dataset_params: Any,
):
    """Create a RAGFlow dataset (params left as Any to avoid a model import)."""
    raw = RAGFlowService.create_dataset_services(create_dataset_params)
    return parse_result(raw)
@ragflowController.post('/delete_datasets')
async def delete_datasets(
    delete_params: Any,
):
    """Delete RAGFlow datasets (params left as Any to avoid a model import)."""
    raw = RAGFlowService.delete_datasets_services(delete_params)
    return parse_result(raw)
def stream_cached_response(cached_answer: str, chat_id: str, start_time: float, cache_source: str = 'cache') -> Generator[str, None, None]:
    """
    Stream a cached answer back to the client as SSE messages.

    Sends an immediate ``ping``, then replays the answer in small chunks to
    simulate typing, and closes with an ``end`` event flagged
    ``from_cache``. (Fix: removed the unused ``first_token_time`` local.)

    Args:
        cached_answer: The full cached answer text.
        chat_id: Session id, used only for logging.
        start_time: ``time.perf_counter()`` taken at request start.
        cache_source: Which cache produced the answer
            ('static_qa' or 'rag_history').
    """
    import time

    server_stream_start = time.time()
    logger.info(f"[CACHE_STREAM {server_stream_start:.3f}] 🚀 开始流式返回缓存答案,chat_id: {chat_id}")

    # Send a connection-established ping immediately.
    connection_time = time.time()
    ping_message = "event: ping\ndata: {\"status\": \"connected\"}\n\n"
    yield ping_message
    logger.info(f"[CACHE_STREAM {connection_time:.3f}] 📡 连接建立ping发送,耗时: {connection_time - server_stream_start:.3f}s")

    # Replay the complete answer in small chunks to simulate typing instead
    # of dumping it in one message.
    if cached_answer:
        chunk_size = 10  # characters per SSE chunk
        answer_chunks = [cached_answer[i:i + chunk_size] for i in range(0, len(cached_answer), chunk_size)]

        for i, chunk in enumerate(answer_chunks):
            chunk_time = time.time()

            # Log first-token latency once.
            if i == 0:
                latency = time.perf_counter() - start_time
                logger.info(f"[CACHE_STREAM {chunk_time:.3f}] 🎯 首token到达(缓存),耗时: {latency:.3f}s")

            body = {'answer': chunk, 'chunk_index': i, 'total_chunks': len(answer_chunks)}
            yield format_sse(body)

            # Optional small delay for a typing effect (tune or remove).
            # time.sleep(0.01)

        logger.info(f"[CACHE_STREAM {chunk_time:.3f}] 📝 缓存答案流式发送完成,共 {len(answer_chunks)} 个chunk")

    # Terminating 'end' event, marking the answer as cache-sourced.
    stream_end_time = time.time()
    end_message = format_sse({
        'status': 'completed',
        'from_cache': True,
        'source': cache_source,
        'total_time': stream_end_time - server_stream_start
    }, event='end')
    yield end_message

    logger.info(f"[CACHE_STREAM {stream_end_time:.3f}] 🏁 缓存流式响应完成")
    logger.info(f"[CACHE_STREAM {stream_end_time:.3f}] 📊 缓存答案长度: {len(cached_answer)} 字符")

    total_time = time.perf_counter() - start_time
    logger.info(f'[CACHE_STREAM {time.time():.3f}] ⏱️ Total Cache Stream Duration ({chat_id}): {total_time:.3f}s')
logger.info(f'[CACHE_STREAM {time.time():.3f}] ⏱️ Total Cache Stream Duration ({chat_id}): {total_time:.3f}s') |