""" 简化的RAGFlow控制器 - 集成语义缓存 """ import asyncio import hashlib import json import time from typing import Any, Dict, Generator, Optional, Tuple from fastapi import APIRouter, Depends, Request from fastapi.responses import StreamingResponse from module_admin.service.login_service import LoginService from module_admin.service.ragflow_service import RAGFlowService from module_admin.service.search_service import SearchService from module_admin.entity.vo.ragflow_vo import ( ConverseWithChatAssistantModel, CreateSessionWithChatModel, UpdateChatAssistantModel, RagflowListQueryModel ) from utils.log_util import logger from utils.response_util import ResponseUtil from utils.semantic_cache_service import get_semantic_cache_service, lookup_question, store_qa_pair, get_question_hash from utils.static_qa_service import get_static_qa_service async def _async_store_qa(chat_id: str, question: str, answer: str, redis) -> None: """ 异步存储问答对到语义缓存 """ store_hash = get_question_hash(question) logger.info(f"[SemanticCache] 存储QA | chat_id={chat_id} | question={question} | hash={store_hash} | answer_length={len(answer)}") try: await store_qa_pair(chat_id, question, answer, redis) except Exception as e: logger.warning(f"[SemanticCache] 异步存储失败: {e}") # 使用标准的APIRouter,增加认证依赖 ragflowController = APIRouter(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)]) def format_sse(data: dict, event: str | None = None) -> str: """格式化SSE数据""" payload = json.dumps(data, ensure_ascii=False) prefix = f'event: {event}\n' if event else '' return f'{prefix}data: {payload}\n\n' def parse_result(result: Dict[str, Any]) -> Dict[str, Any]: """统一结果解析""" code = result.get('code', 0) if code != 0: msg = result.get('message') or result.get('msg') or '接口异常' return ResponseUtil.error(msg=msg, data=result.get('data', None)) return ResponseUtil.success(data=result.get('data', None)) def build_chat_cache_key(chat_id: str, question: str) -> str: """构建聊天缓存键""" digest = 
hashlib.sha256(question.encode('utf-8')).hexdigest() return f'ragflow:chat:{chat_id}:{digest}' def remove_style_hint(question: str) -> Tuple[str, bool]: """ 移除语言风格提示词 Args: question: 原始问题 Returns: (处理后的问题, 是否移除了风格提示词) """ if not question: return question, False import re pattern = r'语言风格\s*[::]\s*[^,。!?\n]*' match = re.search(pattern, question) if match: style_hint = match.group(0) cleaned_question = re.sub(pattern, '', question).strip() cleaned_question = re.sub(r'\s+', ' ', cleaned_question) logger.info(f'[StyleHint] 移除风格提示词: {style_hint} | 原始问题: {question} | 处理后: {cleaned_question}') return cleaned_question, True return question, False # 非流式对话接口 @ragflowController.post('/converse_with_chat_assistant') async def converse_with_chat_assistant( request: Request, converse_params: ConverseWithChatAssistantModel, ): """ 与聊天助手进行对话 - 集成语义缓存版本(支持流式和非流式) 匹配流程: 1. 移除语言风格提示词 2. 静态问答匹配 (threshold=0.70) 3. RAG历史缓存匹配 (threshold=0.60) 4. RAG服务调用(使用原始问题) """ start_time = time.perf_counter() original_question = converse_params.question cleaned_question, style_removed = remove_style_hint(original_question) if style_removed: converse_params.question = cleaned_question # 获取redis实例 redis = getattr(request.app.state, 'redis', None) cache_key = None cached_answer = None cache_similarity = 0.0 cache_source = None # ========== 1. 
静态问答匹配 ========== try: static_qa_service = get_static_qa_service() logger.info(f'[StaticQA] 开始匹配 | chat_id={converse_params.chat_id} | question={converse_params.question} | threshold=0.70') static_match, static_sim = static_qa_service.find_match(converse_params.question, threshold=0.70) logger.info(f'[StaticQA] 匹配结果 | similarity={static_sim:.2f} | matched={static_match is not None}') if static_match: cached_answer = static_match.get('answer', '') cache_similarity = static_sim cache_source = 'static_qa' logger.info(f'[RAG_SOURCE] 命中静态FAQ | chat_id={converse_params.chat_id} | question={converse_params.question} | similarity={cache_similarity:.2f} | answer_length={len(cached_answer)}') logger.info(f'[StaticQA] 流式响应使用静态问答答案,chat_id={converse_params.chat_id}') return StreamingResponse( stream_cached_response(cached_answer, converse_params.chat_id, start_time, cache_source='static_qa'), media_type='text/event-stream', headers={ 'Cache-Control': 'no-cache', 'Connection': 'keep-alive', 'X-Accel-Buffering': 'no', 'Transfer-Encoding': 'chunked' } ) except Exception as e: logger.warning(f'[StaticQA] 静态问答匹配失败: {e}') # ========== 2. 
RAG历史缓存查找 ========== logger.info(f'[SemanticCache] 准备执行RAG历史缓存查找 | redis={redis is not None} | chat_id={converse_params.chat_id}') if redis: lookup_hash = get_question_hash(converse_params.question) logger.info(f'[SemanticCache] 开始查找 | chat_id={converse_params.chat_id} | question={converse_params.question} | hash={lookup_hash} | threshold=0.60') service = get_semantic_cache_service() logger.info(f'[SemanticCache] service实例: {service}') cache_result = await service.lookup( converse_params.chat_id, converse_params.question, redis ) logger.info(f'[SemanticCache] 查找结果 | found={cache_result is not None}') if cache_result: cached_answer, cache_similarity = cache_result cache_source = 'rag_history' logger.info(f'[RAG_SOURCE] 命中RAG会话历史 | chat_id={converse_params.chat_id} | question={converse_params.question} | similarity={cache_similarity:.2f} | answer_length={len(cached_answer)}') logger.info(f'[SemanticCache] 流式响应使用RAG历史缓存答案,chat_id={converse_params.chat_id}') return StreamingResponse( stream_cached_response(cached_answer, converse_params.chat_id, start_time, cache_source='rag_history'), media_type='text/event-stream', headers={ 'Cache-Control': 'no-cache', 'Connection': 'keep-alive', 'X-Accel-Buffering': 'no', 'Transfer-Encoding': 'chunked' } ) # 检查是否应该使用搜索服务 try: intent = await SearchService.classify_intent(converse_params.question) logger.info(f"[RAG_PROFILE] Intent Router ({converse_params.chat_id}): {intent}") if intent == 'SEARCH': return await SearchService.handle_search_chat(converse_params, redis) except Exception as e: logger.warning(f"意图分类失败,使用RAG服务: {e}") # 直接使用同步RAGFlow服务 try: # 恢复原始问题(包含语言风格提示词) if style_removed: converse_params.question = original_question logger.info(f'[StyleHint] 恢复原始问题: {original_question}') logger.info(f'[RAG_SOURCE] 调用原生RAG服务 | chat_id={converse_params.chat_id} | question={converse_params.question}') result = RAGFlowService.converse_with_chat_assistant_services(converse_params) cache_question = cleaned_question if style_removed else 
converse_params.question store_hash = get_question_hash(cache_question) logger.info(f'[RAG_CACHE] 准备存储 | chat_id={converse_params.chat_id} | question={cache_question} | hash={store_hash}') async def make_cache_store(chat_id: str, question: str): async def store_answer(answer: str): if redis and answer and len(answer.strip()) >= 10: try: await _async_store_qa(chat_id, question, answer, redis) logger.info(f'[RAG_CACHE] 语义缓存存储成功 | chat_id={chat_id} | answer_length={len(answer)}') except Exception as e: logger.warning(f'语义缓存存储失败: {e}') return store_answer cache_store = await make_cache_store(converse_params.chat_id, cache_question) return StreamingResponse( stream_ragflow_response(result, converse_params.chat_id, start_time, cache_store_func=cache_store), media_type='text/event-stream', headers={ 'Cache-Control': 'no-cache', 'Connection': 'keep-alive', 'X-Accel-Buffering': 'no', 'Transfer-Encoding': 'chunked' } ) except Exception as e: logger.exception(f'[RAG_PROFILE] ragflow对话异常: {e}') return ResponseUtil.error(msg=f'对话服务异常: {str(e)}') def stream_ragflow_response(result: Generator, chat_id: str, start_time: float, cache_store_func=None) -> Generator[str, None, None]: """ 流式处理RAGFlow响应 - 同步版本,修复首token延迟 """ import time server_stream_start = time.time() logger.info(f"[RAG_SERVER {server_stream_start:.3f}] 🚀 开始流式响应处理,chat_id: {chat_id}") last_answer = "" first_token_received = False start_stream_time = time.perf_counter() # 立即发送连接建立消息,解决首token延迟 connection_time = time.time() ping_message = "event: ping\ndata: {\"status\": \"connected\"}\n\n" yield ping_message logger.info(f"[RAG_SERVER {connection_time:.3f}] 📡 连接建立ping发送,耗时: {connection_time - server_stream_start:.3f}s") chunk_count = 0 try: for chunk in result: chunk_time = time.time() chunk_count += 1 # 检查第一个token的延迟 if not first_token_received: first_token_received = True latency = time.perf_counter() - start_stream_time logger.info(f"[RAG_SERVER {chunk_time:.3f}] 🎯 首token到达,耗时: {latency:.3f}s") # 处理chunk数据 payload = 
chunk.get('data') if isinstance(chunk, dict) else chunk if not payload: logger.info(f"[RAG_SERVER {chunk_time:.3f}] ⚠️ 空chunk跳过") continue body = payload if isinstance(payload, dict) else {'data': payload} # 处理错误 if isinstance(chunk, dict) and chunk.get('code') and chunk.get('code') != 0: logger.error(f"[RAG_SERVER {chunk_time:.3f}] ❌ RAGFlow Stream Error: {chunk}") error_message = format_sse({'message': chunk.get('message', '流式处理异常')}, event='error') yield error_message continue # 处理answer字段的增量更新 if isinstance(body, dict) and 'answer' in body: current_answer = body['answer'] # 安全检查:确保current_answer不为None且为字符串 if current_answer is None: logger.warning(f"[RAG_SERVER {chunk_time:.3f}] ⚠️ current_answer为None,跳过处理") continue if not isinstance(current_answer, str): logger.warning(f"[RAG_SERVER {chunk_time:.3f}] ⚠️ current_answer不是字符串,类型: {type(current_answer)}") current_answer = str(current_answer) # 计算增量内容 if current_answer.startswith(last_answer): delta = current_answer[len(last_answer):] if delta: body['answer'] = delta last_answer = current_answer logger.info(f"[RAG_SERVER {chunk_time:.3f}] 📝 Chunk #{chunk_count} 处理完成,delta长度: {len(delta)}") yield format_sse(body) else: # 上下文重置的备用处理 last_answer = current_answer logger.info(f"[RAG_SERVER {chunk_time:.3f}] 🔄 Chunk #{chunk_count} 上下文重置") yield format_sse(body) else: logger.info(f"[RAG_SERVER {chunk_time:.3f}] 📦 Chunk #{chunk_count} 其他数据: {body}") yield format_sse(body) # 流结束 stream_end_time = time.time() end_message = format_sse({'status': 'completed'}, event='end') yield end_message logger.info(f'[RAG_SERVER {stream_end_time:.3f}] 流式响应完成,chat_id: {chat_id}') logger.info(f'[RAG_SERVER {stream_end_time:.3f}] 总共处理chunk数量: {chunk_count}') logger.info(f'[RAG_SOURCE] 原生RAG流式响应完成 | chat_id={chat_id} | total_chunks={chunk_count} | answer_length={len(last_answer)}') except Exception as exc: error_time = time.time() logger.exception(f'[RAG_SERVER {error_time:.3f}] ragflow流式对话异常: {exc}') error_message = format_sse({'message': 
str(exc)}, event='error') yield error_message finally: total_time = time.perf_counter() - start_time logger.info(f'[RAG_SERVER {time.time():.3f}] ⏱️ Total Stream Duration ({chat_id}): {total_time:.3f}s') if cache_store_func and last_answer and len(last_answer.strip()) >= 10: try: import asyncio cache_task = asyncio.create_task(cache_store_func(last_answer)) cache_task.add_done_callback( lambda t: logger.info(f'[RAG_CACHE] 缓存存储完成 | chat_id={chat_id} | answer_length={len(last_answer)}') if not t.exception() else logger.warning(f'[RAG_CACHE] 缓存存储失败: {t.exception()}') ) except Exception as cache_err: logger.warning(f'[RAG_CACHE] 缓存存储失败: {cache_err}') # 聊天助手列表 @ragflowController.post('/get_chat_assistant_list') async def get_chat_assistant_list( query_params: RagflowListQueryModel, # 使用正确的Pydantic模型 ): """获取聊天助手列表""" result = RAGFlowService.get_chat_assistant_list_services(query_params) return parse_result(result) # 创建会话 @ragflowController.post('/create_session_with_chat') async def create_session_with_chat( create_params: CreateSessionWithChatModel, # 使用正确的Pydantic模型 ): """创建聊天会话""" result = RAGFlowService.create_session_with_chat_services(create_params) return parse_result(result) # 更新聊天助手 @ragflowController.post('/update_chat_assistant') async def update_chat_assistant( update_params: UpdateChatAssistantModel, # 使用正确的Pydantic模型 ): """更新聊天助手""" result = RAGFlowService.update_chat_assistant_services(update_params) return parse_result(result) # 数据集相关接口 @ragflowController.post('/dataset_list') async def get_system_ragflow_list( request: Request, ragflow_list_query: Any, # 使用Any避免导入具体模型类 ): """获取数据集列表""" result = RAGFlowService.get_ragflow_dataset_list_services(None, ragflow_list_query) return parse_result(result) @ragflowController.post('/create_dataset') async def create_dataset( create_dataset_params: Any, # 使用Any避免导入具体模型类 ): """创建数据集""" result = RAGFlowService.create_dataset_services(create_dataset_params) return parse_result(result) 
@ragflowController.post('/delete_datasets')
async def delete_datasets(
    delete_params: Any,  # typed as Any to avoid importing the concrete model class
):
    """Delete datasets."""
    result = RAGFlowService.delete_datasets_services(delete_params)
    return parse_result(result)


def stream_cached_response(
    cached_answer: str,
    chat_id: str,
    start_time: float,
    cache_source: str = 'cache',
    chunk_size: int = 10,
) -> Generator[str, None, None]:
    """
    Stream a cached answer back as SSE frames.

    A ping frame is sent first, then the answer in slices of ``chunk_size``
    characters, then an ``end`` frame carrying the cache source and the time
    spent streaming.

    Args:
        cached_answer: The cached answer text; an empty/None answer produces
            only the ping and end frames.
        chat_id: Chat id, used for logging only.
        start_time: ``time.perf_counter()`` value taken at request start.
        cache_source: Origin of the answer ('static_qa' or 'rag_history').
        chunk_size: Characters per answer frame; values below 1 are treated
            as 1.
    """
    server_stream_start = time.time()
    logger.info(f"[CACHE_STREAM {server_stream_start:.3f}] 🚀 开始流式返回缓存答案,chat_id: {chat_id}")

    # Send the connection-established frame immediately.
    connection_time = time.time()
    yield "event: ping\ndata: {\"status\": \"connected\"}\n\n"
    logger.info(f"[CACHE_STREAM {connection_time:.3f}] 📡 连接建立ping发送,耗时: {connection_time - server_stream_start:.3f}s")

    # The answer is sent in small slices so the client renders it with the
    # same typing effect as a live stream.
    if cached_answer:
        step = max(1, int(chunk_size))  # range() rejects a zero step
        answer_chunks = [cached_answer[pos:pos + step] for pos in range(0, len(cached_answer), step)]
        total_chunks = len(answer_chunks)

        chunk_time = time.time()
        for index, piece in enumerate(answer_chunks):
            chunk_time = time.time()

            # Latency of the first token, measured from request start.
            if index == 0:
                latency = time.perf_counter() - start_time
                logger.info(f"[CACHE_STREAM {chunk_time:.3f}] 🎯 首token到达(缓存),耗时: {latency:.3f}s")

            yield format_sse({'answer': piece, 'chunk_index': index, 'total_chunks': total_chunks})

        logger.info(f"[CACHE_STREAM {chunk_time:.3f}] 📝 缓存答案流式发送完成,共 {len(answer_chunks)} 个chunk")

    # End of stream.
    stream_end_time = time.time()
    yield format_sse({
        'status': 'completed',
        'from_cache': True,
        'source': cache_source,
        'total_time': stream_end_time - server_stream_start
    }, event='end')

    logger.info(f"[CACHE_STREAM {stream_end_time:.3f}] 🏁 缓存流式响应完成")
    if cached_answer:
        logger.info(f"[CACHE_STREAM {stream_end_time:.3f}] 📊 缓存答案长度: {len(cached_answer)} 字符")

    total_time = time.perf_counter() - start_time
    logger.info(f'[CACHE_STREAM {time.time():.3f}] ⏱️ Total Cache Stream Duration ({chat_id}): {total_time:.3f}s')