diff --git a/ruoyi-fastapi-backend/module_admin/controller/ragflow_controller.py b/ruoyi-fastapi-backend/module_admin/controller/ragflow_controller.py
index 538d9ad..9bfac16 100644
--- a/ruoyi-fastapi-backend/module_admin/controller/ragflow_controller.py
+++ b/ruoyi-fastapi-backend/module_admin/controller/ragflow_controller.py
@@ -1,11 +1,12 @@
 """
-Simplified RAGFlow controller - async complexity removed
+Simplified RAGFlow controller - with semantic cache integration
 """
 
+import asyncio
 import hashlib
 import json
 import time
-from typing import Any, Dict, Generator
+from typing import Any, Dict, Generator, Optional, Tuple
 
 from fastapi import APIRouter, Depends, Request
 from fastapi.responses import StreamingResponse
@@ -21,6 +22,18 @@ from module_admin.entity.vo.ragflow_vo import (
 )
 from utils.log_util import logger
 from utils.response_util import ResponseUtil
+from utils.semantic_cache_service import get_semantic_cache_service, lookup_question, store_qa_pair
+
+
+async def _async_store_qa(chat_id: str, question: str, answer: str, redis) -> None:
+    """
+    Store a Q&A pair in the semantic cache without blocking the caller
+    """
+    try:
+        await store_qa_pair(chat_id, question, answer, redis)
+    except Exception as e:
+        logger.warning(f"[SemanticCache] async store failed: {e}")
+
 
 # Standard APIRouter with an authentication dependency
 ragflowController = APIRouter(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)])
@@ -55,7 +68,7 @@ async def converse_with_chat_assistant(
     converse_params: ConverseWithChatAssistantModel,  # validated Pydantic model
 ):
     """
-    Converse with the chat assistant - synchronous version, reduced complexity
+    Converse with the chat assistant - with semantic cache integration
     """
     start_time = time.perf_counter()
 
@@ -63,6 +76,19 @@ async def converse_with_chat_assistant(
     redis = getattr(request.app.state, 'redis', None)
     cache_key = None
 
+    # Semantic cache lookup (non-streaming conversations only)
+    if not converse_params.stream and redis:
+        cache_result = await lookup_question(
+            converse_params.chat_id,
+            converse_params.question,
+            redis
+        )
+        if cache_result:
+            cached_answer, similarity = cache_result
+            logger.info(f'[SemanticCache] cache hit (similarity={similarity:.2f}): chat_id={converse_params.chat_id}')
+            # Return the cached answer wrapped in the standard response format
+            return ResponseUtil.success(data={'answer': cached_answer, 'from_cache': True, 'similarity': similarity})
+
     # Decide whether the search service should handle this question
     try:
         intent = await SearchService.classify_intent(converse_params.question)
@@ -73,7 +99,7 @@ async def converse_with_chat_assistant(
     except Exception as e:
         logger.warning(f"Intent classification failed, falling back to the RAG service: {e}")
 
-    # Cache check
+    # Cache check (pre-existing exact-match cache)
     if not converse_params.stream and redis:
         cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
         try:
@@ -104,7 +130,23 @@ async def converse_with_chat_assistant(
     # Non-streaming response
     response = parse_result(result)
 
-    # Set cache
+    # Populate the semantic cache (store the Q&A pair)
+    if not converse_params.stream and redis and isinstance(result, dict) and result.get('code') == 0:
+        try:
+            answer_data = result.get('data', {})
+            answer_text = answer_data.get('answer') if isinstance(answer_data, dict) else str(answer_data)
+            if answer_text and len(answer_text.strip()) >= 10:
+                # Fire-and-forget store so the response is not delayed
+                asyncio.create_task(_async_store_qa(
+                    converse_params.chat_id,
+                    converse_params.question,
+                    answer_text,
+                    redis
+                ))
+        except Exception as e:
+            logger.warning(f"Semantic cache store failed: {e}")
+
+    # Populate the pre-existing exact-match cache
     if redis and cache_key and isinstance(result, dict) and result.get('code') == 0:
         try:
             await redis.set(cache_key, json.dumps(result.get('data', {}), ensure_ascii=False), ex=60)
diff --git a/ruoyi-fastapi-backend/module_admin/entity/vo/ragflow_vo.py b/ruoyi-fastapi-backend/module_admin/entity/vo/ragflow_vo.py
index 7150481..dc3d149 100644
--- a/ruoyi-fastapi-backend/module_admin/entity/vo/ragflow_vo.py
+++ b/ruoyi-fastapi-backend/module_admin/entity/vo/ragflow_vo.py
@@ -136,7 +136,8 @@ class ConverseWithChatAssistantModel(BaseModel):
     """
     Conversation chat model
     """
-    model_config = ConfigDict(alias_generator=to_camel, from_attributes=True)
+    # Drop the alias_generator so the endpoint accepts the original snake_case field names
+    model_config = ConfigDict(from_attributes=True)
 
     chat_id: str = Field(default = None, description='Session ID')
     question: str = Field(default = None, description='Question')
diff --git a/ruoyi-fastapi-backend/test/test_maxkb_performance.py b/ruoyi-fastapi-backend/test/test_maxkb_performance.py
new file mode 100644
index 0000000..9cd60ad
--- /dev/null
+++ b/ruoyi-fastapi-backend/test/test_maxkb_performance.py
@@ -0,0 +1,399 @@
+"""
+MaxKB RAG benchmark - measures time to first token
+"""
+import asyncio
+import aiohttp
+import time
+import sys
+import os
+import json
+
+# Add the project root to the Python path
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+# MaxKB configuration
+MAXKB_API_URL = "http://10.0.0.202:8080"
+API_KEY = "application-b758e71ecb8d4237f91795365b433c35"
+CSRF_TOKEN = "bDzz45UzuDGB33YCykegwEG9QMBuLba3cisIDsL5Z4tsXTrntxZq8nhVJf7qT5Pq"
+
+# DeepSeek configuration
+DEEPSEEK_API_BASE = "https://api.deepseek.com"
+DEEPSEEK_API_KEY = "sk-56b608b26a6949e4b09b5bf5f11c8f5b"
+DEEPSEEK_MODEL = "deepseek-chat"
+
+# Default test question (kept in Chinese; the assistants under test answer in Chinese)
+TEST_QUESTION = "你好,请介绍一下你自己"
+
+
+async def test_deepseek_first_token(question: str = TEST_QUESTION):
+    """Measure the time to first token for a direct DeepSeek API call"""
+
+    print("=" * 70)
+    print("DeepSeek API benchmark - time to first token")
+    print("=" * 70)
+
+    url = f"{DEEPSEEK_API_BASE}/chat/completions"
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {DEEPSEEK_API_KEY}"
+    }
+    data = {
+        "model": DEEPSEEK_MODEL,
+        "messages": [
+            {"role": "user", "content": question}
+        ],
+        "stream": True
+    }
+
+    print(f"\n[1/2] Sending test question: \"{question}\"")
+    print("      Waiting for the response...")
+
+    total_start_time = time.time()
+    first_token_time = None
+    token_count = 0
+    first_token_received = False
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(url, json=data, headers=headers) as response:
+            if response.status != 200:
+                print(f"      ✗ request failed: {response.status}")
+                error_text = await response.text()
+                print(f"      error detail: {error_text}")
+                return None, None
+
+            print("      ✓ connection established, receiving data...")
+            print("\n[2/2] Receiving the stream...")
+            print("-" * 70)
+
+            async for chunk in response.content:
+                current_time = time.time()
+                elapsed = current_time - total_start_time
+
+                # Decode the raw chunk
+                chunk_str = chunk.decode('utf-8')
+
+                if not chunk_str.strip():
+                    continue
+
+                # Record the time of the first token
+                if not first_token_received and 'data:' in chunk_str:
+                    first_token_time = elapsed
+                    first_token_received = True
+                    print(f"\n🎯 time to first token: {first_token_time:.4f} s")
+                    print("-" * 70)
+
+                # Parse and display the content
+                for line in chunk_str.split('\n'):
+                    if line.startswith('data:'):
+                        data_str = line[5:].strip()
+                        if data_str == '[DONE]':
+                            continue
+
+                        try:
+                            data_json = json.loads(data_str)
+                            # The content lives in choices[0].delta.content (DeepSeek response format)
+                            choices = data_json.get('choices', [])
+                            if choices:
+                                delta = choices[0].get('delta', {})
+                                content = delta.get('content', '')
+
+                                if content:
+                                    token_count += 1
+                                    # Show at most the first 40 characters
+                                    display_content = content[:40] + '...' if len(content) > 40 else content
+                                    print(f"[{elapsed:.3f}s] token #{token_count}: {display_content}")
+
+                        except json.JSONDecodeError:
+                            pass
+
+    total_time = time.time() - total_start_time
+
+    # Print the results
+    print("\n" + "=" * 70)
+    print("📊 DeepSeek API benchmark results")
+    print("=" * 70)
+    print(f"  test question: {question}")
+    print(f"  time to first token: {first_token_time:.4f} s" if first_token_time else "  time to first token: N/A")
+    print(f"  total response time: {total_time:.4f} s")
+    print(f"  tokens received: {token_count}")
+    if first_token_time and total_time > 0:
+        print(f"  average token rate: {token_count / (total_time - first_token_time):.2f} tokens/s")
+    print("=" * 70)
+
+    return first_token_time, total_time
+
+
+async def get_chat_id():
+    """Open a session and return a fresh chat_id"""
+    url = f"{MAXKB_API_URL}/chat/api/open"
+    headers = {
+        "accept": "*/*",
+        "Authorization": f"Bearer {API_KEY}"
+    }
+
+    async with aiohttp.ClientSession() as session:
+        async with session.get(url, headers=headers) as response:
+            if response.status == 200:
+                data = await response.json()
+                # The chat_id string is carried in the response's "data" field
+                return data.get("data")
+            else:
+                raise Exception(f"failed to obtain chat_id: {response.status}")
+
+
+async def test_first_token_time(question: str = TEST_QUESTION):
+    """Measure the time to the first received token"""
+
+    print("=" * 70)
+    print("MaxKB RAG benchmark - time to first token")
+    print("=" * 70)
+
+    # 1. Obtain a chat_id
+    print("\n[1/3] Obtaining a chat_id...")
+    chat_id_start_time = time.time()
+    try:
+        chat_id = await get_chat_id()
+        chat_id_time = time.time() - chat_id_start_time
+        print(f"      ✓ chat_id: {chat_id} ({chat_id_time:.4f}s)")
+    except Exception as e:
+        print(f"      ✗ failed to obtain chat_id: {e}")
+        return None, None
+
+    # 2. Start the streaming conversation request
+    print(f"\n[2/3] Sending test question: \"{question}\"")
+    print("      Waiting for the response...")
+
+    url = f"{MAXKB_API_URL}/chat/api/chat_message/{chat_id}"
+    headers = {
+        "accept": "*/*",
+        "Authorization": f"Bearer {API_KEY}",
+        "Content-Type": "application/json",
+        "X-CSRFTOKEN": CSRF_TOKEN
+    }
+    data = {
+        "message": question,
+        "stream": True,
+        "re_chat": False
+    }
+
+    total_start_time = time.time()
+    first_token_time = None
+    token_count = 0
+    first_token_received = False
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(url, json=data, headers=headers) as response:
+            if response.status != 200:
+                print(f"      ✗ request failed: {response.status}")
+                error_text = await response.text()
+                print(f"      error detail: {error_text}")
+                return None, None
+
+            print("      ✓ connection established, receiving data...")
+            print("\n[3/3] Receiving the stream...")
+            print("-" * 70)
+
+            async for chunk in response.content:
+                current_time = time.time()
+                elapsed = current_time - total_start_time
+
+                # Decode the raw chunk
+                chunk_str = chunk.decode('utf-8')
+
+                if not chunk_str.strip():
+                    continue
+
+                # Record the time of the first token
+                if not first_token_received and chunk_str.startswith('data:'):
+                    first_token_time = elapsed
+                    first_token_received = True
+                    print(f"\n🎯 time to first token: {first_token_time:.4f} s")
+                    print("-" * 70)
+
+                # Parse and display the content
+                for line in chunk_str.split('\n'):
+                    if line.startswith('data:'):
+                        data_str = line[5:].strip()
+                        if data_str == '[DONE]':
+                            continue
+
+                        try:
+                            data_json = json.loads(data_str)
+                            # The content lives in the top-level "content" field (MaxKB response format)
+                            content = data_json.get('content', '')
+
+                            if content:
+                                # Fallback detection for chunks that do not start with "data:"
+                                if not first_token_received:
+                                    first_token_time = elapsed
+                                    first_token_received = True
+                                    print(f"\n🎯 time to first token: {first_token_time:.4f} s")
+                                    print("-" * 70)
+
+                                token_count += 1
+                                # Show at most the first 40 characters
+                                display_content = content[:40] + '...' if len(content) > 40 else content
+                                print(f"[{elapsed:.3f}s] token #{token_count}: {display_content}")
+
+                        except json.JSONDecodeError:
+                            pass
+
+    total_time = time.time() - total_start_time
+
+    # 3. Print the results
+    print("\n" + "=" * 70)
+    print("📊 Benchmark results")
+    print("=" * 70)
+    print(f"  test question: {question}")
+    print(f"  time to first token: {first_token_time:.4f} s" if first_token_time else "  time to first token: N/A")
+    print(f"  total response time: {total_time:.4f} s")
+    print(f"  tokens received: {token_count}")
+    if first_token_time and total_time > 0:
+        print(f"  average token rate: {token_count / (total_time - first_token_time):.2f} tokens/s")
+    print("=" * 70)
+
+    return first_token_time, total_time
+
+
+async def run_multiple_tests(n: int = 3, question: str = None):
+    """Run the benchmark n times and report aggregate numbers"""
+    if question is None:
+        question = TEST_QUESTION
+
+    print(f"\n🔄 About to run {n} benchmark passes...")
+    print(f"   test question: {question}")
+    print()
+
+    first_token_times = []
+    total_times = []
+
+    for i in range(n):
+        print(f"\n{'='*70}")
+        print(f"Run #{i + 1}/{n}")
+        print('='*70)
+
+        # Request a fresh chat_id so every run uses a new session
+        chat_id_start_time = time.time()
+        chat_id = await get_chat_id()
+        chat_id_time = time.time() - chat_id_start_time
+        print(f"   obtained chat_id in {chat_id_time:.4f}s")
+
+        url = f"{MAXKB_API_URL}/chat/api/chat_message/{chat_id}"
+        headers = {
+            "accept": "*/*",
+            "Authorization": f"Bearer {API_KEY}",
+            "Content-Type": "application/json",
+            "X-CSRFTOKEN": CSRF_TOKEN
+        }
+        data = {
+            "message": question,
+            "stream": True,
+            "re_chat": False
+        }
+
+        total_start_time = time.time()
+        first_token_time = None
+        token_count = 0
+        first_token_received = False
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(url, json=data, headers=headers) as response:
+                if response.status != 200:
+                    print(f"request failed: {response.status}")
+                    continue
+
+                async for chunk in response.content:
+                    current_time = time.time()
+                    elapsed = current_time - total_start_time
+
+                    chunk_str = chunk.decode('utf-8')
+                    if not chunk_str.strip():
+                        continue
+
+                    if not first_token_received and chunk_str.startswith('data:'):
+                        first_token_time = elapsed
+                        first_token_received = True
+
+                    if chunk_str.startswith('data:'):
+                        data_str = chunk_str[5:].strip()
+                        if data_str == '[DONE]':
+                            continue
+                        try:
+                            data_json = json.loads(data_str)
+                            # The content field carries the actual tokens
+                            if data_json.get('content'):
+                                if not first_token_received:
+                                    first_token_time = elapsed
+                                    first_token_received = True
+                                token_count += 1
+                        except Exception:
+                            pass
+
+        total_time = time.time() - total_start_time
+
+        if first_token_time:
+            first_token_times.append(first_token_time)
+            total_times.append(total_time)
+            print(f"\n   run #{i + 1}: first token={first_token_time:.4f}s, total={total_time:.4f}s, tokens={token_count}")
+        else:
+            print(f"\n   run #{i + 1}: no tokens received")
+        await asyncio.sleep(1)  # pause one second between runs
+
+    # Aggregate statistics
+    print(f"\n{'='*70}")
+    print("📈 Aggregate results")
+    print('='*70)
+    if first_token_times:
+        avg_first_token = sum(first_token_times) / len(first_token_times)
+        min_first_token = min(first_token_times)
+        max_first_token = max(first_token_times)
+        print("   time to first token:")
+        print(f"      - mean: {avg_first_token:.4f} s")
+        print(f"      - min:  {min_first_token:.4f} s")
+        print(f"      - max:  {max_first_token:.4f} s")
+
+    if total_times:
+        avg_total = sum(total_times) / len(total_times)
+        print("   total response time:")
+        print(f"      - mean: {avg_total:.4f} s")
+    print('='*70)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description='MaxKB RAG benchmark')
+    parser.add_argument('question', nargs='?', default=TEST_QUESTION, help='test question (positional argument)')
+    parser.add_argument('--times', type=int, default=1, help='number of runs (default: 1)')
+    parser.add_argument('--deepseek', action='store_true', help='also benchmark the DeepSeek API')
+
+    args = parser.parse_args()
+
+    if args.deepseek:
+        # Benchmark DeepSeek first
+        print("\n" + "=" * 70)
+        print("🆚 Head-to-head benchmark")
+        print("=" * 70)
+        print("\n>>> Benchmarking the DeepSeek API...")
+        ds_first, ds_total = asyncio.run(test_deepseek_first_token(args.question))
+
+        print("\n\n>>> Benchmarking MaxKB...")
+        maxkb_first, maxkb_total = asyncio.run(test_first_token_time(args.question))
+
+        # Comparison table
+        print("\n" + "=" * 70)
+        print("📊 Benchmark comparison")
+        print("=" * 70)
+        print(f"\n  {'service':<15} {'first token':<15} {'total time':<15}")
+        print(f"  {'-'*45}")
+        print(f"  {'DeepSeek':<15} {f'{ds_first:.4f}s' if ds_first else 'N/A':<15} {f'{ds_total:.4f}s' if ds_total else 'N/A':<15}")
+        print(f"  {'MaxKB':<15} {f'{maxkb_first:.4f}s' if maxkb_first else 'N/A':<15} {f'{maxkb_total:.4f}s' if maxkb_total else 'N/A':<15}")
+        print(f"  {'-'*45}")
+        if ds_first and maxkb_first:
+            diff = maxkb_first - ds_first
+            print(f"\n  ⚡ MaxKB first token vs DeepSeek: {diff:+.4f} s ({diff/ds_first*100:.1f}%)")
+        print("=" * 70)
+    elif args.times == 1:
+        asyncio.run(test_first_token_time(args.question))
+    else:
+        asyncio.run(run_multiple_tests(args.times, args.question))
diff --git a/ruoyi-fastapi-backend/test/test_sse.py b/ruoyi-fastapi-backend/test/test_sse.py
index 4d7b142..67bfbca 100644
--- a/ruoyi-fastapi-backend/test/test_sse.py
+++ b/ruoyi-fastapi-backend/test/test_sse.py
@@ -19,7 +19,7 @@ AUTH_TOKEN = os.environ.get(
     "Bearer ",
 )
 DEFAULT_CHAT_ID = os.environ.get(
-    "RAGFLOW_CHAT_ID", "db4bb966895b11f08cda0242ac130006"
+    "RAGFLOW_CHAT_ID", "a455abcadbcb11f0884b0242ac130006"
 )
 DEFAULT_SESSION_ID = os.environ.get(
     "RAGFLOW_SESSION_ID", "38d765e48a3811f0be310242ac130006"
diff --git a/ruoyi-fastapi-backend/test/test_semantic_cache.py b/ruoyi-fastapi-backend/test/test_semantic_cache.py
new file mode 100644
index 0000000..7dbd509
--- /dev/null
+++ b/ruoyi-fastapi-backend/test/test_semantic_cache.py
@@ -0,0 +1,301 @@
+"""
+Semantic cache tests
+
+Covers:
+1. Cache store and lookup
+2. Semantic matching quality
+3. Cache performance
+
+Usage:
+python test_semantic_cache.py
+
+Author: AI Assistant
+"""
+
+import asyncio
+import sys
+import os
+import time
+
+# Add the project root to the Python path
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from utils.semantic_cache_service import (
+    SemanticCacheService,
+    get_semantic_cache_service,
+    lookup_question,
+    store_qa_pair,
+    clear_cache
+)
+from config.get_redis import RedisUtil
+
+
+# Shared Redis client
+_redis_client = None
+
+
+async def get_redis():
+    """Return a lazily created, shared Redis client"""
+    global _redis_client
+    if _redis_client is None:
+        _redis_client = await RedisUtil.create_redis_pool()
+    return _redis_client
+
+
+async def test_basic_functionality():
+    """Basic store/lookup behaviour"""
+    print("\n" + "=" * 60)
+    print("Test 1: basic functionality")
+    print("=" * 60)
+
+    # Redis client and a service instance
+    redis = await get_redis()
+    cache = SemanticCacheService(redis)
+
+    # Start from a clean slate
+    await cache.clear(chat_id="test_chat_001")
+
+    # Exact match (test questions are kept in Chinese - the matcher targets Chinese text)
+    print("\n1. Exact match...")
+    result = await cache.lookup("test_chat_001", "康达公司成立于哪一年?", redis)
+    print(f"   first lookup (expected miss): {result}")
+    assert result is None, "the first lookup should miss"
+
+    # Store the Q&A pair
+    success = await cache.store(
+        "test_chat_001",
+        "康达公司成立于哪一年?",
+        "康达公司成立于2012年,是一家专注于智能机器人研发的高科技企业。",
+        redis
+    )
+    print(f"   store Q&A pair: {'ok' if success else 'failed'}")
+
+    # Look it up again
+    result = await cache.lookup("test_chat_001", "康达公司成立于哪一年?", redis)
+    print(f"   exact lookup (now cached): {'hit' if result else 'miss'}")
+    if result:
+        print(f"   answer: {result[0][:50]}...")
+        print(f"   similarity: {result[1]:.2f}")
+        assert result[1] == 1.0, "an exact match should score 1.0"
+
+    print("   ✓ exact-match test passed")
+
+
+async def test_semantic_matching():
+    """Semantic matching quality"""
+    print("\n" + "=" * 60)
+    print("Test 2: semantic matching")
+    print("=" * 60)
+
+    redis = await get_redis()
+    cache = SemanticCacheService(redis)
+
+    # Clean slate for this scenario
+    await cache.clear(chat_id="test_chat_002")
+
+    # Store the original Q&A pair
+    original_question = "公司的退货政策是什么?"
+    original_answer = "我们支持7天无理由退换货,您可以在收到商品后7天内申请退换。"
+    await cache.store("test_chat_002", original_question, original_answer, redis)
+
+    # Try a range of rewordings
+    test_cases = [
+        ("公司的退货政策是什么?", "original question"),
+        ("退货政策是怎么样的?", "reworded"),
+        ("如果我想退货可以吗?", "semantically related"),
+        ("你们怎么退货?", "colloquial phrasing"),
+        ("产品质量有问题怎么换?", "related but different"),
+        ("今天天气怎么样?", "completely unrelated"),
+    ]
+
+    print("\nSemantic matching results:")
+    print("-" * 60)
+    for question, desc in test_cases:
+        result = await cache.lookup("test_chat_002", question, redis)
+        if result:
+            answer, similarity = result
+            print(f"  [{desc}]")
+            print(f"    Q: {question[:30]}...")
+            print(f"    hit! similarity={similarity:.2f}")
+        else:
+            print(f"  [{desc}]")
+            print(f"    Q: {question[:30]}...")
+            print(f"    miss")
+
+
+async def test_performance():
+    """Cache performance"""
+    print("\n" + "=" * 60)
+    print("Test 3: cache performance")
+    print("=" * 60)
+
+    redis = await get_redis()
+    cache = SemanticCacheService(redis)
+
+    # Clean slate and seed data
+    await cache.clear(chat_id="test_chat_003")
+
+    # Warm the cache
+    questions = [
+        "你们公司提供哪些服务?",
+        "如何联系客服?",
+        "产品售后政策是什么?",
+        "发货需要多长时间?",
+        "支持哪些支付方式?",
+    ]
+
+    print("\nWarming the cache...")
+    start_time = time.time()
+    for i, q in enumerate(questions):
+        await cache.store(
+            "test_chat_003",
+            q,
+            f"This is the canned answer to question #{i+1}",
+            redis
+        )
+    warm_time = time.time() - start_time
+    print(f"  storing 5 Q&A pairs took: {warm_time*1000:.2f}ms")
+
+    # Benchmark: exact match vs semantic match
+    print("\nBenchmark:")
+
+    # Exact match
+    start_time = time.time()
+    for _ in range(100):
+        await cache.lookup("test_chat_003", "你们公司提供哪些服务?", redis)
+    exact_time = time.time() - start_time
+    print(f"  100 exact lookups: {exact_time*1000:.2f}ms (avg {exact_time*10:.2f}ms each)")
+
+    # Semantic match (scans every entry for the chat_id)
+    start_time = time.time()
+    for _ in range(10):
+        await cache.lookup("test_chat_003", "公司提供什么服务?", redis)
+    semantic_time = time.time() - start_time
+    print(f"  10 semantic lookups: {semantic_time*1000:.2f}ms (avg {semantic_time*100:.2f}ms each)")
+
+
+async def test_cache_statistics():
+    """Cache statistics"""
+    print("\n" + "=" * 60)
+    print("Test 4: cache statistics")
+    print("=" * 60)
+
+    redis = await get_redis()
+    cache = SemanticCacheService(redis)
+
+    # Clean slate and seed data
+    await cache.clear(chat_id="test_chat_stats")
+
+    # Store a couple of Q&A pairs (answers must be at least 10 characters or store() rejects them)
+    await cache.store("test_chat_stats", "问题1", "这是问题1的详细答案,长度超过十个字符。", redis)
+    await cache.store("test_chat_stats", "问题2", "这是问题2的详细答案,长度超过十个字符。", redis)
+
+    # Generate a few hits
+    await cache.lookup("test_chat_stats", "问题1", redis)
+    await cache.lookup("test_chat_stats", "问题1", redis)
+    await cache.lookup("test_chat_stats", "问题2", redis)
+
+    # Fetch the statistics
+    stats = await cache.get_stats("test_chat_stats", redis)
+    print("\nCache statistics:")
+    print(f"  total entries: {stats.get('total_entries', 0)}")
+    print("  configuration:")
+    print(f"    - max cache size: {stats['config']['max_cache_size']}")
+    print(f"    - similarity threshold: {stats['config']['similarity_threshold']}")
+    print(f"    - entry TTL: {stats['config']['cache_ttl_hours']}h")
+
+
+async def test_integration_with_ragflow():
+    """Integration test: simulate the RAG flow"""
+    print("\n" + "=" * 60)
+    print("Test 5: integration - simulated RAG flow")
+    print("=" * 60)
+
+    redis = await get_redis()
+    cache = SemanticCacheService(redis)
+
+    # Mock RAG service
+    class MockRAGService:
+        async def query(self, question):
+            # Simulate RAG latency
+            await asyncio.sleep(0.1)
+            return f"RAG answer: a detailed response to {question} ..."
+
+    rag_service = MockRAGService()
+
+    # Clean slate
+    await cache.clear(chat_id="test_integration")
+
+    test_questions = [
+        "康达机器人有哪些产品?",
+        "康达机器人有哪些产品?",  # repeated question - should be an exact cache hit
+        "你们的机器人产品有哪些?",  # similar wording - may hit via fuzzy matching
+        "产品价格是多少?",  # new question - expected miss
+    ]
+
+    print("\nSimulating consecutive user questions:")
+    print("-" * 60)
+
+    for i, question in enumerate(test_questions):
+        print(f"\nUser question #{i+1}: {question[:30]}...")
+
+        # 1. First consult the cache
+        cached = await cache.lookup("test_integration", question, redis)
+
+        if cached:
+            answer, similarity = cached
+            print(f"  ✓ cache hit! (similarity={similarity:.2f})")
+            print("  → returning the cached answer")
+        else:
+            print("  ✗ cache miss")
+            print("  → calling the RAG service...")
+
+            # 2. Call RAG
+            answer = await rag_service.query(question)
+            print("  → RAG returned an answer")
+
+            # 3. Store the result in the cache
+            await cache.store("test_integration", question, answer, redis)
+            print("  → stored in the cache")
+
+        print(f"  answer preview: {answer[:40]}...")
+
+
+async def main():
+    """Run the whole test suite"""
+    print("\n" + "╔" + "═" * 58 + "╗")
+    print("║" + " " * 17 + "Semantic cache test suite" + " " * 16 + "║")
+    print("╚" + "═" * 58 + "╝")
+
+    try:
+        await test_basic_functionality()
+        await test_semantic_matching()
+        await test_performance()
+        await test_cache_statistics()
+        await test_integration_with_ragflow()
+
+        print("\n" + "=" * 60)
+        print("✓ all tests finished!")
+        print("=" * 60)
+
+    except Exception as e:
+        print(f"\n✗ test failed: {e}")
+        import traceback
+        traceback.print_exc()
+
+    finally:
+        # Clean up the test data
+        redis = await get_redis()
+        cache = SemanticCacheService(redis)
+        await cache.clear(chat_id="test_chat_001")
+        await cache.clear(chat_id="test_chat_002")
+        await cache.clear(chat_id="test_chat_003")
+        await cache.clear(chat_id="test_chat_stats")
+        await cache.clear(chat_id="test_integration")
+        print("\ntest data cleaned up")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/ruoyi-fastapi-backend/utils/semantic_cache_service.py b/ruoyi-fastapi-backend/utils/semantic_cache_service.py
new file mode 100644
index 0000000..a6df828
--- /dev/null
+++ b/ruoyi-fastapi-backend/utils/semantic_cache_service.py
@@ -0,0 +1,472 @@
+"""
+Semantic cache service - caches answers keyed on past Q&A history
+
+What it does:
+1. Records the Q&A pairs (question + answer) returned through RAG
+2. On a follow-up question, searches the cache semantically first
+3. On a cache hit, returns the stored answer directly
+4. On a miss, calls RAG and stores the result in the cache
+
+Approach:
+- Q&A pairs (question text + answer text) are stored in Redis
+- Similarity is a simple text heuristic (no embedding model is available)
+- Cache key: rag:semantic:cache:{chat_id}:{question_hash}
+- Multi-level matching: exact match → fuzzy match
+
+Author: AI Assistant
+Date: 2024
+"""
+
+import hashlib
+import json
+import time
+import re
+import logging
+from typing import Optional, Tuple, Dict, Any, List
+from dataclasses import dataclass, asdict
+
+# Module logger
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class CacheEntry:
+    """A single cached Q&A entry"""
+    question: str
+    answer: str
+    created_at: float
+    hit_count: int = 0
+    chat_id: str = ""
+
+    def to_dict(self) -> dict:
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, data: dict) -> 'CacheEntry':
+        return cls(**data)
+
+
+class SemanticCacheService:
+    """
+    Semantic cache service
+
+    Three matching strategies are supported:
+    1. Exact match: the questions are identical
+    2. Keyword match: extract the question's core keywords and compare them
+    3. Fuzzy match: computed text similarity
+    """
+
+    # Cache configuration
+    CACHE_PREFIX = "rag:semantic:cache"
+    MAX_CACHE_SIZE = 1000  # maximum number of cached entries per chat_id
+    SIMILARITY_THRESHOLD = 0.75  # minimum similarity for a fuzzy hit
+    CACHE_TTL_HOURS = 24  # entry time-to-live (hours)
+
+    def __init__(self, redis_client=None):
+        """
+        Initialise the semantic cache service
+
+        Args:
+            redis_client: Redis client instance; if None the cache is effectively disabled
+        """
+        self.redis = redis_client
+        self._default_redis = None
+
+    def _get_redis(self):
+        """Return the Redis client, or None when none was injected"""
+        if self.redis is not None:
+            return self.redis
+        if self._default_redis is None:
+            # Creating a client here would require an async call,
+            # so callers are expected to inject a redis instance instead.
+            logger.warning("No Redis client provided; the semantic cache is unavailable")
+        return self._default_redis
+
+    def _build_cache_key(self, chat_id: str, question_hash: str) -> str:
+        """Build the cache key"""
+        return f"{self.CACHE_PREFIX}:{chat_id}:{question_hash}"
+
+    def _hash_question(self, question: str) -> str:
+        """Hash a question into a short unique identifier"""
+        # Normalise the text first (whitespace, punctuation)
+        normalized = self._normalize_question(question)
+        return hashlib.md5(normalized.encode('utf-8')).hexdigest()[:16]
+
+    def _normalize_question(self, question: str) -> str:
+        """Normalise the question text"""
+        # Strip leading/trailing whitespace
+        q = question.strip()
+        # Collapse internal whitespace
+        q = re.sub(r'\s+', ' ', q)
+        # Map full-width Chinese punctuation to ASCII
+        q = q.replace('?', '?').replace(',', ',').replace('。', '.')
+        return q
+
+    def _extract_keywords(self, question: str) -> List[str]:
+        """
+        Extract keywords from a question (used for semantic matching)
+
+        Strategy:
+        1. Keep the full question as the primary matching basis
+        2. Extract nouns, verbs and other core tokens
+        3. Filter out stopwords
+        """
+        # Minimal Chinese stopword list (kept as functional data)
+        stopwords = {'的', '是', '了', '在', '有', '和', '与', '或', '吗', '呢', '吧', '啊', '哦', '请问', '能不能', '可以', '怎么', '如何', '什么', '多少', '几个'}
+
+        # Tokenise on runs of word/CJK characters (no real segmenter is used)
+        words = re.findall(r'[\w\u4e00-\u9fff]+', question.lower())
+
+        # Drop stopwords and single-character tokens
+        keywords = [w for w in words if w not in stopwords and len(w) > 1]
+
+        return keywords
+
+    def _calculate_text_similarity(self, q1: str, q2: str) -> float:
+        """
+        Score the similarity of two questions
+
+        Combines Jaccard keyword overlap with a length penalty
+        """
+        # Normalise both sides
+        n1 = self._normalize_question(q1)
+        n2 = self._normalize_question(q2)
+
+        # Identical questions short-circuit to 1.0
+        if n1 == n2:
+            return 1.0
+
+        # Keyword overlap
+        kw1 = set(self._extract_keywords(q1))
+        kw2 = set(self._extract_keywords(q2))
+
+        if not kw1 or not kw2:
+            return 0.0
+
+        # Jaccard similarity
+        intersection = len(kw1 & kw2)
+        union = len(kw1 | kw2)
+        jaccard_sim = intersection / union if union > 0 else 0
+
+        # Length similarity (penalises large length differences)
+        len_sim = 1 - abs(len(n1) - len(n2)) / max(len(n1), len(n2))
+        len_sim = max(0, len_sim)  # clamp to non-negative
+
+        # Weighted combination (keywords 0.7, length 0.3)
+        similarity = 0.7 * jaccard_sim + 0.3 * len_sim
+
+        return similarity
+
+    async def lookup(
+        self,
+        chat_id: str,
+        question: str,
+        redis_client=None
+    ) -> Optional[Tuple[str, float]]:
+        """
+        Search the cache for a similar question
+
+        Args:
+            chat_id: chat session ID
+            question: the user's question
+            redis_client: Redis client (optional)
+
+        Returns:
+            (answer, similarity) on a cache hit, otherwise None
+        """
+        redis = redis_client or self._get_redis()
+        if redis is None:
+            return None
+
+        try:
+            # 1. Exact match
+            question_hash = self._hash_question(question)
+            exact_key = self._build_cache_key(chat_id, question_hash)
+
+            cached_data = await redis.get(exact_key)
+            if cached_data:
+                entry = CacheEntry.from_dict(json.loads(cached_data))
+                entry.hit_count += 1
+                # Persist the updated hit count and refresh the TTL
+                await redis.set(exact_key, json.dumps(entry.to_dict()), ex=self.CACHE_TTL_HOURS * 3600)
+                logger.info(f"[SemanticCache] exact hit: {question[:30]}...")
+                return entry.answer, 1.0
+
+            # 2. Semantic fuzzy match (scans every entry for this chat_id)
+            # Fetch all cache keys for this chat_id
+            # NOTE: KEYS is O(N) over the keyspace; SCAN would be preferable in production
+            pattern = f"{self.CACHE_PREFIX}:{chat_id}:*"
+            cache_keys = await redis.keys(pattern)
+
+            if not cache_keys:
+                return None
+
+            # Walk the entries and score each one
+            best_key = None
+            best_match = None
+            best_similarity = 0.0
+
+            for key in cache_keys:
+                # Skip the exact-match key we just checked
+                if key == exact_key:
+                    continue
+
+                cached_data = await redis.get(key)
+                if not cached_data:
+                    continue
+
+                try:
+                    entry = CacheEntry.from_dict(json.loads(cached_data))
+                    similarity = self._calculate_text_similarity(question, entry.question)
+
+                    if similarity > best_similarity:
+                        best_similarity = similarity
+                        best_match = entry
+                        best_key = key
+
+                except (json.JSONDecodeError, KeyError) as e:
+                    logger.warning(f"[SemanticCache] failed to parse cache entry: {key}, {e}")
+                    continue
+
+            # Accept the best candidate if it clears the similarity threshold
+            if best_match and best_similarity >= self.SIMILARITY_THRESHOLD:
+                logger.info(f"[SemanticCache] semantic hit (similarity={best_similarity:.2f}): {question[:30]}...")
+                # Update the hit count on the entry that actually matched
+                best_match.hit_count += 1
+                await redis.set(best_key, json.dumps(best_match.to_dict()), ex=self.CACHE_TTL_HOURS * 3600)
+                return best_match.answer, best_similarity
+
+            return None
+
+        except Exception as e:
+            logger.error(f"[SemanticCache] lookup failed: {e}")
+            return None
+
+    async def store(
+        self,
+        chat_id: str,
+        question: str,
+        answer: str,
+        redis_client=None
+    ) -> bool:
+        """
+        Store a Q&A pair in the cache
+
+        Args:
+            chat_id: chat session ID
+            question: the user's question
+            answer: the answer returned by RAG
+            redis_client: Redis client (optional)
+
+        Returns:
+            True when the pair was stored (or already present)
+        """
+        redis = redis_client or self._get_redis()
+        if redis is None:
+            return False
+
+        try:
+            # Reject answers that are too short to be useful
+            if not answer or len(answer.strip()) < 10:
+                logger.info(f"[SemanticCache] answer too short, not caching: {answer[:20]}...")
+                return False
+
+            # Build the cache entry
+            question_hash = self._hash_question(question)
+            cache_key = self._build_cache_key(chat_id, question_hash)
+
+            entry = CacheEntry(
+                question=question,
+                answer=answer,
+                created_at=time.time(),
+                hit_count=0,
+                chat_id=chat_id
+            )
+
+            # Skip if the question is already cached
+            existing = await redis.get(cache_key)
+            if existing:
+                logger.debug(f"[SemanticCache] question already cached, skipping: {question[:30]}...")
+                return True
+
+            # Evict old entries if this chat_id holds too many
+            await self._cleanup_old_cache(chat_id, redis)
+
+            # Write the entry
+            await redis.set(
+                cache_key,
+                json.dumps(entry.to_dict()),
+                ex=self.CACHE_TTL_HOURS * 3600
+            )
+
+            logger.info(f"[SemanticCache] cached: {question[:30]}...")
+            return True
+
+        except Exception as e:
+            logger.error(f"[SemanticCache] store failed: {e}")
+            return False
+
+    async def _cleanup_old_cache(self, chat_id: str, redis_client=None) -> int:
+        """
+        Evict old entries once the cache grows too large
+
+        Strategy: delete the entries with the earliest creation time
+
+        Returns:
+            The number of deleted entries
+        """
+        redis = redis_client or self._get_redis()
+        if redis is None:
+            return 0
+
+        try:
+            pattern = f"{self.CACHE_PREFIX}:{chat_id}:*"
+            cache_keys = await redis.keys(pattern)
+
+            if len(cache_keys) <= self.MAX_CACHE_SIZE:
+                return 0
+
+            # Collect each entry's creation time
+            cache_info = []
+            for key in cache_keys:
+                data = await redis.get(key)
+                if data:
+                    try:
+                        entry = CacheEntry.from_dict(json.loads(data))
+                        cache_info.append((key, entry.created_at))
+                    except Exception:
+                        pass
+
+            # Sort by creation time and delete the oldest
+            cache_info.sort(key=lambda x: x[1])
+            delete_count = len(cache_keys) - self.MAX_CACHE_SIZE
+
+            deleted = 0
+            for key, _ in cache_info[:delete_count]:
+                await redis.delete(key)
+                deleted += 1
+
+            if deleted > 0:
+                logger.info(f"[SemanticCache] evicted {deleted} old entries")
+
+            return deleted
+
+        except Exception as e:
+            logger.error(f"[SemanticCache] cleanup failed: {e}")
+            return 0
+
+    async def clear(self, chat_id: str = None, redis_client=None) -> bool:
+        """
+        Clear cached entries
+
+        Args:
+            chat_id: chat_id to clear; clears everything when None
+            redis_client: Redis client (optional)
+
+        Returns:
+            True on success
+        """
+        redis = redis_client or self._get_redis()
+        if redis is None:
+            return False
+
+        try:
+            if chat_id:
+                pattern = f"{self.CACHE_PREFIX}:{chat_id}:*"
+            else:
+                pattern = f"{self.CACHE_PREFIX}:*"
+
+            cache_keys = await redis.keys(pattern)
+
+            if cache_keys:
+                await redis.delete(*cache_keys)
+                logger.info(f"[SemanticCache] cleared {len(cache_keys)} entries")
+
+            return True
+
+        except Exception as e:
+            logger.error(f"[SemanticCache] clear failed: {e}")
+            return False
+
+    async def get_stats(self, chat_id: str = None, redis_client=None) -> Dict[str, Any]:
+        """
+        Return cache statistics
+
+        Args:
+            chat_id: chat_id to inspect; covers everything when None
+            redis_client: Redis client (optional)
+
+        Returns:
+            A dictionary of statistics
+        """
+        redis = redis_client or self._get_redis()
+        if redis is None:
+            return {}
+
+        try:
+            if chat_id:
+                pattern = f"{self.CACHE_PREFIX}:{chat_id}:*"
+            else:
+                pattern = f"{self.CACHE_PREFIX}:*"
+
+            cache_keys = await redis.keys(pattern)
+
+            stats = {
+                "total_entries": len(cache_keys),
+                "chat_id": chat_id or "all",
+                "config": {
+                    "max_cache_size": self.MAX_CACHE_SIZE,
+                    "similarity_threshold": self.SIMILARITY_THRESHOLD,
+                    "cache_ttl_hours": self.CACHE_TTL_HOURS
+                }
+            }
+
+            # Hit-count distribution
+            if cache_keys:
+                total_hits = 0
+                for key in cache_keys[:100]:  # sample at most the first 100 entries
+                    data = await redis.get(key)
+                    if data:
+                        try:
+                            entry = CacheEntry.from_dict(json.loads(data))
+                            total_hits += entry.hit_count
+                        except Exception:
+                            pass
+
+                stats["total_hits"] = total_hits
+                stats["avg_hits_per_entry"] = total_hits / min(len(cache_keys), 100)
+
+            return stats
+
+        except Exception as e:
+            logger.error(f"[SemanticCache] stats failed: {e}")
+            return {}
+
+
+# Module-level singleton
+_semantic_cache_service: Optional[SemanticCacheService] = None
+
+
+def get_semantic_cache_service() -> SemanticCacheService:
+    """Return the semantic cache service singleton"""
+    global _semantic_cache_service
+    if _semantic_cache_service is None:
+        _semantic_cache_service = SemanticCacheService()
+    return _semantic_cache_service
+
+
+# Convenience wrappers
+async def lookup_question(chat_id: str, question: str, redis_client=None) -> Optional[Tuple[str, float]]:
+    """Look a question up in the cache"""
+    service = get_semantic_cache_service()
+    return await service.lookup(chat_id, question, redis_client)
+
+
+async def store_qa_pair(chat_id: str, question: str, answer: str, redis_client=None) -> bool:
+    """Store a Q&A pair"""
+    service = get_semantic_cache_service()
+    return await service.store(chat_id, question, answer, redis_client)
+
+
+async def clear_cache(chat_id: str = None, redis_client=None) -> bool:
+    """Clear the cache"""
+    service = get_semantic_cache_service()
+    return await service.clear(chat_id, redis_client)
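
Reviewer note (not part of the changeset): the sketch below shows the read-through pattern the two public helpers are meant for, in case the cache is wired into another call site. It is a minimal illustration, not the shipped integration - the Redis URL, the redis-py >= 5 client, and the fake_rag stand-in are assumptions; the real call site is converse_with_chat_assistant in ragflow_controller.py.

    # Read-through usage sketch for utils/semantic_cache_service.py (illustrative only).
    import asyncio

    import redis.asyncio as aioredis  # assumes redis-py >= 5 is installed

    from utils.semantic_cache_service import lookup_question, store_qa_pair


    async def answer_with_cache(chat_id: str, question: str, answer_fn, redis):
        # 1. Consult the cache first: exact hash match, then keyword-overlap similarity.
        cached = await lookup_question(chat_id, question, redis)
        if cached:
            answer, similarity = cached
            return answer, similarity  # served from the cache
        # 2. Miss: call the expensive backend (RAG, web search, ...).
        answer = await answer_fn(question)
        # 3. Store for next time; store() skips answers shorter than 10 characters.
        await store_qa_pair(chat_id, question, answer, redis)
        return answer, None


    async def main():
        # Assumed local Redis instance; in the app the client comes from app.state.redis.
        redis = aioredis.from_url("redis://localhost:6379", decode_responses=True)

        async def fake_rag(q: str) -> str:  # hypothetical stand-in for the real RAG call
            return f"A sufficiently long canned answer for: {q}"

        print(await answer_with_cache("demo_chat", "退货政策是什么?", fake_rag, redis))
        # Identical question: guaranteed exact hit via the question hash.
        print(await answer_with_cache("demo_chat", "退货政策是什么?", fake_rag, redis))
        # Reworded question: exercises the fuzzy path; whether it hits depends on the
        # keyword tokenisation (no segmenter) and the 0.75 similarity threshold.
        print(await answer_with_cache("demo_chat", "退货政策是怎样的?", fake_rag, redis))
        await redis.aclose()


    if __name__ == "__main__":
        asyncio.run(main())

Design note: because _extract_keywords tokenises on contiguous word/CJK runs rather than real word segmentation, unsegmented Chinese questions often collapse into a single token, so in practice most hits come from the exact-match level; the fuzzy level mainly helps when questions share punctuation- or whitespace-delimited segments.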