Add question-answer caching; add tests for MaxKB
parent 10ca1e2c00
commit 5d15567460
@@ -1,11 +1,12 @@
"""
简化的RAGFlow控制器 - 移除异步复杂性
简化的RAGFlow控制器 - 集成语义缓存
"""

import asyncio
import hashlib
import json
import time
from typing import Any, Dict, Generator
from typing import Any, Dict, Generator, Optional, Tuple

from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
@@ -21,6 +22,18 @@ from module_admin.entity.vo.ragflow_vo import (
)
from utils.log_util import logger
from utils.response_util import ResponseUtil
from utils.semantic_cache_service import get_semantic_cache_service, lookup_question, store_qa_pair


async def _async_store_qa(chat_id: str, question: str, answer: str, redis) -> None:
    """
    异步存储问答对到语义缓存
    """
    try:
        await store_qa_pair(chat_id, question, answer, redis)
    except Exception as e:
        logger.warning(f"[SemanticCache] 异步存储失败: {e}")


# 使用标准的APIRouter,增加认证依赖
ragflowController = APIRouter(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)])
@@ -55,7 +68,7 @@ async def converse_with_chat_assistant(
    converse_params: ConverseWithChatAssistantModel,  # 使用正确的Pydantic模型
):
    """
    与聊天助手进行对话 - 同步版本,简化复杂性
    与聊天助手进行对话 - 集成语义缓存版本
    """
    start_time = time.perf_counter()

@@ -63,6 +76,19 @@ async def converse_with_chat_assistant(
    redis = getattr(request.app.state, 'redis', None)
    cache_key = None

    # 语义缓存查找(非流式对话)
    if not converse_params.stream and redis:
        cache_result = await lookup_question(
            converse_params.chat_id,
            converse_params.question,
            redis
        )
        if cache_result:
            cached_answer, similarity = cache_result
            logger.info(f'[SemanticCache] 命中缓存 (相似度={similarity:.2f}): chat_id={converse_params.chat_id}')
            # 返回缓存的答案,包装成标准响应格式
            return ResponseUtil.success(data={'answer': cached_answer, 'from_cache': True, 'similarity': similarity})

    # 检查是否应该使用搜索服务
    try:
        intent = await SearchService.classify_intent(converse_params.question)
@@ -73,7 +99,7 @@ async def converse_with_chat_assistant(
    except Exception as e:
        logger.warning(f"意图分类失败,使用RAG服务: {e}")

    # 缓存检查
    # 缓存检查(原有的精确缓存)
    if not converse_params.stream and redis:
        cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
        try:
@@ -104,7 +130,23 @@ async def converse_with_chat_assistant(
    # 非流式响应
    response = parse_result(result)

    # 设置缓存
    # 设置语义缓存(存储问答对)
    if not converse_params.stream and redis and isinstance(result, dict) and result.get('code') == 0:
        try:
            answer_data = result.get('data', {})
            answer_text = answer_data.get('answer') if isinstance(answer_data, dict) else str(answer_data)
            if answer_text and len(answer_text.strip()) >= 10:
                # 异步存储,不阻塞响应返回
                asyncio.create_task(_async_store_qa(
                    converse_params.chat_id,
                    converse_params.question,
                    answer_text,
                    redis
                ))
        except Exception as e:
            logger.warning(f"语义缓存存储失败: {e}")

    # 设置原有精确缓存
    if redis and cache_key and isinstance(result, dict) and result.get('code') == 0:
        try:
            await redis.set(cache_key, json.dumps(result.get('data', {}), ensure_ascii=False), ex=60)

@@ -136,7 +136,8 @@ class ConverseWithChatAssistantModel(BaseModel):
    会话聊天模型
    """

    model_config = ConfigDict(alias_generator=to_camel, from_attributes=True)
    # 移除alias_generator,使用原始的snake_case参数名
    model_config = ConfigDict(from_attributes=True)

    chat_id: str = Field(default = None, description='会话ID')
    question: str = Field(default = None, description='问题')

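Dropping `alias_generator=to_camel` means this endpoint now reads the request body with snake_case field names. A rough before/after sketch (placeholder values; any other fields of the model are omitted here):

old_payload = {"chatId": "xxx", "question": "你好"}   # previously accepted via the camelCase aliases
new_payload = {"chat_id": "xxx", "question": "你好"}  # expected after this change (matches the field names directly)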
ruoyi-fastapi-backend/test/test_maxkb_performance.py (new file, 399 lines)
@@ -0,0 +1,399 @@
"""
|
||||
MaxKB RAG 性能测试 - 测试首个token接收时间
|
||||
"""
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
|
||||
# 添加项目路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# MaxKB 配置
|
||||
MAXKB_API_URL = "http://10.0.0.202:8080"
|
||||
API_KEY = "application-b758e71ecb8d4237f91795365b433c35"
|
||||
CSRF_TOKEN = "bDzz45UzuDGB33YCykegwEG9QMBuLba3cisIDsL5Z4tsXTrntxZq8nhVJf7qT5Pq"
|
||||
|
||||
# DeepSeek 配置
|
||||
DEEPSEEK_API_BASE = "https://api.deepseek.com"
|
||||
DEEPSEEK_API_KEY = "sk-56b608b26a6949e4b09b5bf5f11c8f5b"
|
||||
DEEPSEEK_MODEL = "deepseek-chat"
|
||||
|
||||
# 测试问题
|
||||
TEST_QUESTION = "你好,请介绍一下你自己"
|
||||
|
||||
|
||||
async def test_deepseek_first_token(question: str = TEST_QUESTION):
    """测试 DeepSeek 直接调用的首个token接收时间"""

    print("=" * 70)
    print("DeepSeek API 性能测试 - 首个Token接收时间")
    print("=" * 70)

    url = f"{DEEPSEEK_API_BASE}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {DEEPSEEK_API_KEY}"
    }
    data = {
        "model": DEEPSEEK_MODEL,
        "messages": [
            {"role": "user", "content": question}
        ],
        "stream": True
    }

    print(f"\n[1/2] 发送测试问题: \"{question}\"")
    print(" 等待响应...")

    total_start_time = time.time()
    first_token_time = None
    token_count = 0
    first_token_received = False

    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data, headers=headers) as response:
            if response.status != 200:
                print(f" ✗ 请求失败: {response.status}")
                error_text = await response.text()
                print(f" 错误信息: {error_text}")
                return

            print(f" ✓ 连接建立,开始接收数据...")
            print(f"\n[2/2] 接收流式数据...")
            print("-" * 70)

            async for chunk in response.content:
                current_time = time.time()
                elapsed = current_time - total_start_time

                # 解析 chunk 数据
                chunk_str = chunk.decode('utf-8')

                if not chunk_str.strip():
                    continue

                # 记录首个token时间
                if not first_token_received and 'data:' in chunk_str:
                    first_token_time = elapsed
                    first_token_received = True
                    print(f"\n🎯 首个Token接收时间: {first_token_time:.4f} 秒")
                    print("-" * 70)

                # 解析并显示内容
                for line in chunk_str.split('\n'):
                    if line.startswith('data:'):
                        data_str = line[5:].strip()
                        if data_str == '[DONE]':
                            continue

                        try:
                            data_json = json.loads(data_str)
                            # 从 choices[0].delta.content 获取实际内容(DeepSeek响应格式)
                            choices = data_json.get('choices', [])
                            if choices:
                                delta = choices[0].get('delta', {})
                                content = delta.get('content', '')

                                if content:
                                    token_count += 1
                                    # 显示内容(前40个字符)
                                    display_content = content[:40] + '...' if len(content) > 40 else content
                                    print(f"[{elapsed:.3f}s] Token #{token_count}: {display_content}")

                        except json.JSONDecodeError:
                            pass

    total_time = time.time() - total_start_time

    # 输出测试结果
    print("\n" + "=" * 70)
    print("📊 DeepSeek API 性能测试结果")
    print("=" * 70)
    print(f" 测试问题: {question}")
    print(f" 首个Token时间: {first_token_time:.4f} 秒" if first_token_time else " 首个Token时间: N/A")
    print(f" 总响应时间: {total_time:.4f} 秒")
    print(f" 接收Token数: {token_count}")
    if first_token_time and total_time > first_token_time:
        print(f" 平均Token速率: {token_count / (total_time - first_token_time):.2f} tokens/秒")
    print("=" * 70)

    return first_token_time, total_time

async def get_chat_id():
    """获取一个新的 chat_id"""
    url = f"{MAXKB_API_URL}/chat/api/open"
    headers = {
        "accept": "*/*",
        "Authorization": f"Bearer {API_KEY}"
    }

    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=headers) as response:
            if response.status == 200:
                data = await response.json()
                # chat_id 在返回 JSON 的 data 字段中
                return data.get("data")
            else:
                raise Exception(f"获取 chat_id 失败: {response.status}")

async def test_first_token_time(question: str = TEST_QUESTION):
    """测试首个token的接收时间"""

    print("=" * 70)
    print("MaxKB RAG 性能测试 - 首个Token接收时间")
    print("=" * 70)

    # 1. 获取 chat_id
    print("\n[1/3] 获取 chat_id...")
    chat_id_start_time = time.time()
    try:
        chat_id = await get_chat_id()
        chat_id_time = time.time() - chat_id_start_time
        print(f" ✓ chat_id: {chat_id} ({chat_id_time:.4f}s)")
    except Exception as e:
        print(f" ✗ 获取 chat_id 失败: {e}")
        return

    # 2. 发起流式对话请求
    print(f"\n[2/3] 发送测试问题: \"{question}\"")
    print(" 等待响应...")

    url = f"{MAXKB_API_URL}/chat/api/chat_message/{chat_id}"
    headers = {
        "accept": "*/*",
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
        "X-CSRFTOKEN": CSRF_TOKEN
    }
    data = {
        "message": question,
        "stream": True,
        "re_chat": False
    }

    total_start_time = time.time()
    first_token_time = None
    token_count = 0
    first_token_received = False

    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data, headers=headers) as response:
            if response.status != 200:
                print(f" ✗ 请求失败: {response.status}")
                error_text = await response.text()
                print(f" 错误信息: {error_text}")
                return

            print(f" ✓ 连接建立,开始接收数据...")
            print(f"\n[3/3] 接收流式数据...")
            print("-" * 70)

            async for chunk in response.content:
                current_time = time.time()
                elapsed = current_time - total_start_time

                # 解析 chunk 数据
                chunk_str = chunk.decode('utf-8')

                if not chunk_str.strip():
                    continue

                # 记录首个token时间
                if not first_token_received and chunk_str.startswith('data:'):
                    first_token_time = elapsed
                    first_token_received = True
                    print(f"\n🎯 首个Token接收时间: {first_token_time:.4f} 秒")
                    print("-" * 70)

                # 解析并显示内容
                for line in chunk_str.split('\n'):
                    if line.startswith('data:'):
                        data_str = line[5:].strip()
                        if data_str == '[DONE]':
                            continue

                        try:
                            data_json = json.loads(data_str)
                            # 从 content 字段获取实际内容(MaxKB响应格式)
                            content = data_json.get('content', '')

                            if content:
                                if not first_token_received:
                                    first_token_time = elapsed
                                    first_token_received = True
                                    print(f"\n🎯 首个Token接收时间: {first_token_time:.4f} 秒")
                                    print("-" * 70)

                                token_count += 1
                                # 显示内容(前40个字符)
                                display_content = content[:40] + '...' if len(content) > 40 else content
                                print(f"[{elapsed:.3f}s] Token #{token_count}: {display_content}")

                        except json.JSONDecodeError:
                            pass

    total_time = time.time() - total_start_time

    # 3. 输出测试结果
    print("\n" + "=" * 70)
    print("📊 性能测试结果")
    print("=" * 70)
    print(f" 测试问题: {question}")
    print(f" 首个Token时间: {first_token_time:.4f} 秒" if first_token_time else " 首个Token时间: N/A")
    print(f" 总响应时间: {total_time:.4f} 秒")
    print(f" 接收Token数: {token_count}")
    if first_token_time and total_time > first_token_time:
        print(f" 平均Token速率: {token_count / (total_time - first_token_time):.2f} tokens/秒")
    print("=" * 70)

    return first_token_time, total_time

async def run_multiple_tests(n: int = 3, question: str = None):
    """运行多次测试取平均值"""
    if question is None:
        question = TEST_QUESTION

    print(f"\n🔄 准备运行 {n} 次性能测试...")
    print(f" 测试问题: {question}")
    print()

    first_token_times = []
    total_times = []

    for i in range(n):
        print(f"\n{'='*70}")
        print(f"测试 #{i + 1}/{n}")
        print('='*70)

        # 重新获取 chat_id(每次测试使用新的会话)
        chat_id_start_time = time.time()
        chat_id = await get_chat_id()
        chat_id_time = time.time() - chat_id_start_time
        print(f" 获取chat_id: {chat_id_time:.4f}s")

        url = f"{MAXKB_API_URL}/chat/api/chat_message/{chat_id}"
        headers = {
            "accept": "*/*",
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
            "X-CSRFTOKEN": CSRF_TOKEN
        }
        data = {
            "message": question,
            "stream": True,
            "re_chat": False
        }

        total_start_time = time.time()
        first_token_time = None
        token_count = 0
        first_token_received = False

        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=data, headers=headers) as response:
                if response.status != 200:
                    print(f"请求失败: {response.status}")
                    continue

                async for chunk in response.content:
                    current_time = time.time()
                    elapsed = current_time - total_start_time

                    chunk_str = chunk.decode('utf-8')
                    if not chunk_str.strip():
                        continue

                    if not first_token_received and chunk_str.startswith('data:'):
                        first_token_time = elapsed
                        first_token_received = True

                    if chunk_str.startswith('data:'):
                        data_str = chunk_str[5:].strip()
                        if data_str == '[DONE]':
                            continue
                        try:
                            data_json = json.loads(data_str)
                            # 从 content 字段获取实际内容
                            if data_json.get('content'):
                                if not first_token_received:
                                    first_token_time = elapsed
                                    first_token_received = True
                                token_count += 1
                        except json.JSONDecodeError:
                            pass

        total_time = time.time() - total_start_time

        if first_token_time:
            first_token_times.append(first_token_time)
            total_times.append(total_time)

        print(f"\n 结果 #{i + 1}: 首Token={first_token_time:.4f}s, 总时间={total_time:.4f}s, Tokens={token_count}")
        await asyncio.sleep(1)  # 等待1秒再进行下次测试

    # 计算平均值
    print(f"\n{'='*70}")
    print("📈 多次测试统计结果")
    print('='*70)
    if first_token_times:
        avg_first_token = sum(first_token_times) / len(first_token_times)
        min_first_token = min(first_token_times)
        max_first_token = max(first_token_times)
        print(f" 首个Token时间:")
        print(f" - 平均值: {avg_first_token:.4f} 秒")
        print(f" - 最小值: {min_first_token:.4f} 秒")
        print(f" - 最大值: {max_first_token:.4f} 秒")

    if total_times:
        avg_total = sum(total_times) / len(total_times)
        print(f" 总响应时间:")
        print(f" - 平均值: {avg_total:.4f} 秒")
    print('='*70)

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='MaxKB RAG 性能测试')
    parser.add_argument('question', nargs='?', default=TEST_QUESTION, help='测试问题(直接写在命令后面)')
    parser.add_argument('--times', type=int, default=1, help='测试次数 (默认1次)')
    parser.add_argument('--deepseek', action='store_true', help='同时测试 DeepSeek API 性能')

    args = parser.parse_args()

    if args.deepseek:
        # 先测试 DeepSeek
        print("\n" + "=" * 70)
        print("🆚 性能对比测试")
        print("=" * 70)
        print("\n>>> 开始测试 DeepSeek API...")
        ds_first, ds_total = asyncio.run(test_deepseek_first_token(args.question))

        print("\n\n>>> 开始测试 MaxKB...")
        maxkb_first, maxkb_total = asyncio.run(test_first_token_time(args.question))

        # 对比结果
        print("\n" + "=" * 70)
        print("📊 性能对比结果")
        print("=" * 70)
        print(f"\n {'服务':<15} {'首个Token':<15} {'总响应时间':<15}")
        print(f" {'-'*45}")
        print(f" {'DeepSeek':<15} {f'{ds_first:.4f}s' if ds_first else 'N/A':<15} {f'{ds_total:.4f}s' if ds_total else 'N/A':<15}")
        print(f" {'MaxKB':<15} {f'{maxkb_first:.4f}s' if maxkb_first else 'N/A':<15} {f'{maxkb_total:.4f}s' if maxkb_total else 'N/A':<15}")
        print(f" {'-'*45}")
        if ds_first and maxkb_first:
            diff = maxkb_first - ds_first
            print(f"\n ⚡ MaxKB 比 DeepSeek 慢: {diff:.4f} 秒 ({diff/ds_first*100:.1f}%)")
        print("=" * 70)
    elif args.times == 1:
        asyncio.run(test_first_token_time(args.question))
    else:
        asyncio.run(run_multiple_tests(args.times, args.question))
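For reference, the argument parser above supports invocations along these lines (running from the directory that contains the script is an assumption on my part):

python test_maxkb_performance.py                                    # single run with the default question
python test_maxkb_performance.py "介绍一下康达机器人" --times 3      # custom question, 3 runs averaged
python test_maxkb_performance.py --deepseek                         # also test DeepSeek and print the comparison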
ruoyi-fastapi-backend/test/test_semantic_cache.py (new file, 301 lines)
@@ -0,0 +1,301 @@
"""
|
||||
测试语义缓存功能
|
||||
|
||||
功能:
|
||||
1. 测试缓存的存储和查找
|
||||
2. 测试语义匹配效果
|
||||
3. 验证缓存性能
|
||||
|
||||
使用方式:
|
||||
python test_semantic_cache.py
|
||||
|
||||
作者:AI Assistant
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils.semantic_cache_service import (
|
||||
SemanticCacheService,
|
||||
get_semantic_cache_service,
|
||||
lookup_question,
|
||||
store_qa_pair,
|
||||
clear_cache
|
||||
)
|
||||
from config.get_redis import RedisUtil
|
||||
|
||||
|
||||
# 全局Redis客户端
|
||||
_redis_client = None
|
||||
|
||||
|
||||
async def get_redis():
|
||||
"""获取Redis客户端"""
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
_redis_client = await RedisUtil.create_redis_pool()
|
||||
return _redis_client
|
||||
|
||||
|
||||
async def test_basic_functionality():
    """测试基本功能"""
    print("\n" + "=" * 60)
    print("测试1: 基本功能验证")
    print("=" * 60)

    # 获取Redis客户端
    redis = await get_redis()

    # 创建缓存服务实例
    cache = SemanticCacheService(redis)

    # 清理测试缓存
    await cache.clear(chat_id="test_chat_001")

    # 测试1: 精确匹配
    print("\n1. 测试精确匹配...")
    result = await cache.lookup("test_chat_001", "康达公司成立于哪一年?", redis)
    print(f" 首次查找(未命中): {result}")
    assert result is None, "首次查找应该未命中"

    # 存储问答对
    success = await cache.store(
        "test_chat_001",
        "康达公司成立于哪一年?",
        "康达公司成立于2012年,是一家专注于智能机器人研发的高科技企业。",
        redis
    )
    print(f" 存储问答对: {'成功' if success else '失败'}")

    # 再次查找
    result = await cache.lookup("test_chat_001", "康达公司成立于哪一年?", redis)
    print(f" 精确查找(已缓存): {'命中' if result else '未命中'}")
    if result:
        print(f" 答案: {result[0][:50]}...")
        print(f" 相似度: {result[1]:.2f}")
        assert result[1] == 1.0, "精确匹配相似度应该为1.0"

    print(" ✓ 精确匹配测试通过")

async def test_semantic_matching():
    """测试语义匹配"""
    print("\n" + "=" * 60)
    print("测试2: 语义匹配验证")
    print("=" * 60)

    redis = await get_redis()
    cache = SemanticCacheService(redis)

    # 清理并重新测试
    await cache.clear(chat_id="test_chat_002")

    # 存储原始问答
    original_question = "公司的退货政策是什么?"
    original_answer = "我们支持7天无理由退换货,您可以在收到商品后7天内申请退换。"
    await cache.store("test_chat_002", original_question, original_answer, redis)

    # 测试各种问法
    test_cases = [
        ("公司的退货政策是什么?", "原问题"),
        ("退货政策是怎么样的?", "同义词替换"),
        ("如果我想退货可以吗?", "语义相关"),
        ("你们怎么退货?", "口语化表达"),
        ("产品质量有问题怎么换?", "相关但不同"),
        ("今天天气怎么样?", "完全无关"),
    ]

    print("\n语义匹配测试结果:")
    print("-" * 60)
    for question, desc in test_cases:
        result = await cache.lookup("test_chat_002", question, redis)
        if result:
            answer, similarity = result
            print(f" [{desc}]")
            print(f" Q: {question[:30]}...")
            print(f" 命中! 相似度={similarity:.2f}")
        else:
            print(f" [{desc}]")
            print(f" Q: {question[:30]}...")
            print(f" 未命中")

async def test_performance():
    """测试缓存性能"""
    print("\n" + "=" * 60)
    print("测试3: 缓存性能验证")
    print("=" * 60)

    redis = await get_redis()
    cache = SemanticCacheService(redis)

    # 清理并准备测试数据
    await cache.clear(chat_id="test_chat_003")

    # 预热缓存
    questions = [
        "你们公司提供哪些服务?",
        "如何联系客服?",
        "产品售后政策是什么?",
        "发货需要多长时间?",
        "支持哪些支付方式?",
    ]

    print("\n预热缓存...")
    start_time = time.time()
    for i, q in enumerate(questions):
        await cache.store(
            "test_chat_003",
            q,
            f"这是针对问题的回答 #{i+1}",
            redis
        )
    warm_time = time.time() - start_time
    print(f" 存储5条问答耗时: {warm_time*1000:.2f}ms")

    # 性能测试:精确匹配 vs 语义匹配
    print("\n性能测试:")

    # 精确匹配测试
    start_time = time.time()
    for _ in range(100):
        await cache.lookup("test_chat_003", "你们公司提供哪些服务?", redis)
    exact_time = time.time() - start_time
    print(f" 100次精确匹配: {exact_time*1000:.2f}ms (平均{exact_time*10:.2f}ms/次)")

    # 语义匹配测试(会扫描所有缓存)
    start_time = time.time()
    for _ in range(10):
        await cache.lookup("test_chat_003", "公司提供什么服务?", redis)
    semantic_time = time.time() - start_time
    print(f" 10次语义匹配: {semantic_time*1000:.2f}ms (平均{semantic_time*100:.2f}ms/次)")

async def test_cache_statistics():
    """测试缓存统计"""
    print("\n" + "=" * 60)
    print("测试4: 缓存统计功能")
    print("=" * 60)

    redis = await get_redis()
    cache = SemanticCacheService(redis)

    # 清理并准备数据
    await cache.clear(chat_id="test_chat_stats")

    # 存储一些问答
    await cache.store("test_chat_stats", "问题1", "答案1", redis)
    await cache.store("test_chat_stats", "问题2", "答案2", redis)

    # 命中几次
    await cache.lookup("test_chat_stats", "问题1", redis)
    await cache.lookup("test_chat_stats", "问题1", redis)
    await cache.lookup("test_chat_stats", "问题2", redis)

    # 获取统计
    stats = await cache.get_stats("test_chat_stats", redis)
    print(f"\n缓存统计信息:")
    print(f" 总条目数: {stats.get('total_entries', 0)}")
    print(f" 配置信息:")
    print(f" - 最大缓存大小: {stats['config']['max_cache_size']}")
    print(f" - 相似度阈值: {stats['config']['similarity_threshold']}")
    print(f" - 缓存过期时间: {stats['config']['cache_ttl_hours']}小时")

async def test_integration_with_ragflow():
    """集成测试:模拟RAG流程"""
    print("\n" + "=" * 60)
    print("测试5: 集成测试 - 模拟RAG流程")
    print("=" * 60)

    redis = await get_redis()
    cache = SemanticCacheService(redis)

    # 模拟RAG服务
    class MockRAGService:
        async def query(self, question):
            # 模拟RAG耗时
            await asyncio.sleep(0.1)
            return f"RAG生成的答案:{question}的详细回答..."

    rag_service = MockRAGService()

    # 清理测试缓存
    await cache.clear(chat_id="test_integration")

    test_questions = [
        "康达机器人有哪些产品?",
        "康达机器人有哪些产品?",  # 重复问题,期望命中缓存
        "你们的机器人产品有哪些?",  # 相似问题,期望语义匹配
        "产品价格是多少?",  # 新问题,不命中
    ]

    print("\n模拟用户连续提问:")
    print("-" * 60)

    for i, question in enumerate(test_questions):
        print(f"\n用户提问 #{i+1}: {question[:30]}...")

        # 1. 先查缓存
        cached = await cache.lookup("test_integration", question, redis)

        if cached:
            answer, similarity = cached
            print(f" ✓ 缓存命中! (相似度={similarity:.2f})")
            print(f" → 直接返回缓存答案")
        else:
            print(f" ✗ 缓存未命中")
            print(f" → 调用RAG服务...")

            # 2. 调用RAG
            answer = await rag_service.query(question)
            print(f" → RAG返回答案")

            # 3. 存储到缓存
            await cache.store("test_integration", question, answer, redis)
            print(f" → 已存入缓存")

        print(f" 答案预览: {answer[:40]}...")

async def main():
    """主测试函数"""
    print("\n" + "╔" + "═" * 58 + "╗")
    print("║" + " " * 15 + "语义缓存功能测试" + " " * 20 + "║")
    print("╚" + "═" * 58 + "╝")

    try:
        await test_basic_functionality()
        await test_semantic_matching()
        await test_performance()
        await test_cache_statistics()
        await test_integration_with_ragflow()

        print("\n" + "=" * 60)
        print("✓ 所有测试完成!")
        print("=" * 60)

    except Exception as e:
        print(f"\n✗ 测试失败: {e}")
        import traceback
        traceback.print_exc()

    finally:
        # 清理测试数据
        redis = await get_redis()
        cache = SemanticCacheService(redis)
        await cache.clear(chat_id="test_chat_001")
        await cache.clear(chat_id="test_chat_002")
        await cache.clear(chat_id="test_chat_003")
        await cache.clear(chat_id="test_chat_stats")
        await cache.clear(chat_id="test_integration")
        print("\n测试数据已清理")


if __name__ == "__main__":
    asyncio.run(main())
@@ -19,7 +19,7 @@ AUTH_TOKEN = os.environ.get(
    "Bearer ",
)
DEFAULT_CHAT_ID = os.environ.get(
    "RAGFLOW_CHAT_ID", "db4bb966895b11f08cda0242ac130006"
    "RAGFLOW_CHAT_ID", "a455abcadbcb11f0884b0242ac130006"
)
DEFAULT_SESSION_ID = os.environ.get(
    "RAGFLOW_SESSION_ID", "38d765e48a3811f0be310242ac130006"
ruoyi-fastapi-backend/utils/semantic_cache_service.py (new file, 472 lines)
@@ -0,0 +1,472 @@
"""
|
||||
语义缓存服务 - 基于问答历史的缓存方案
|
||||
|
||||
功能:
|
||||
1. 记录用户通过RAG返回的对话历史(问题+答案)
|
||||
2. 当用户后续提问时,基于语义先搜索缓存
|
||||
3. 如果命中缓存,直接返回答案
|
||||
4. 如果未命中,调用RAG,并将结果存入缓存
|
||||
|
||||
技术方案:
|
||||
- 使用Redis存储问答对(问题文本 + 答案文本)
|
||||
- 使用简单的文本相似度算法(因为没有现成的Embedding模型)
|
||||
- 缓存键:rag:cache:{chat_id}:{question_hash}
|
||||
- 支持多级匹配:精确匹配 → 模糊匹配
|
||||
|
||||
作者:AI Assistant
|
||||
日期:2024
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
import logging
|
||||
from typing import Optional, Tuple, Dict, Any, List
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class CacheEntry:
    """缓存条目"""
    question: str
    answer: str
    created_at: float
    hit_count: int = 0
    chat_id: str = ""

    def to_dict(self) -> dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> 'CacheEntry':
        return cls(**data)

class SemanticCacheService:
    """
    语义缓存服务

    支持三种匹配策略:
    1. 精确匹配:问题完全相同
    2. 关键语义匹配:提取问题核心关键词,判断语义相似
    3. 模糊匹配:计算文本相似度
    """

    # 缓存配置
    CACHE_PREFIX = "rag:semantic:cache"
    MAX_CACHE_SIZE = 1000  # 每个chat_id最大缓存条数
    SIMILARITY_THRESHOLD = 0.75  # 相似度阈值
    CACHE_TTL_HOURS = 24  # 缓存过期时间(小时)

    def __init__(self, redis_client=None):
        """
        初始化语义缓存服务

        Args:
            redis_client: Redis客户端实例,如果为None则使用全局redis
        """
        self.redis = redis_client
        self._default_redis = None

    def _get_redis(self):
        """获取Redis客户端"""
        if self.redis is not None:
            return self.redis
        if self._default_redis is None:
            # 注意:这里需要异步获取,实际使用时应该传入redis实例
            logger.warning("未提供Redis客户端,语义缓存功能不可用")
        return self._default_redis

    def _build_cache_key(self, chat_id: str, question_hash: str) -> str:
        """构建缓存键"""
        return f"{self.CACHE_PREFIX}:{chat_id}:{question_hash}"

    def _hash_question(self, question: str) -> str:
        """对问题进行哈希,生成唯一标识"""
        # 标准化问题文本(去除多余空格、统一标点)
        normalized = self._normalize_question(question)
        return hashlib.md5(normalized.encode('utf-8')).hexdigest()[:16]

    def _normalize_question(self, question: str) -> str:
        """标准化问题文本"""
        # 去除首尾空格
        q = question.strip()
        # 统一空格
        q = re.sub(r'\s+', ' ', q)
        # 统一中文标点
        q = q.replace('?', '?').replace(',', ',').replace('。', '.')
        return q

    def _extract_keywords(self, question: str) -> List[str]:
        """
        提取问题关键词(用于语义匹配)

        策略:
        1. 保留完整问题作为主要匹配依据
        2. 提取名词、动词等核心词汇
        3. 过滤停用词
        """
        # 简单停用词列表
        stopwords = {'的', '是', '了', '在', '有', '和', '与', '或', '吗', '呢', '吧', '啊', '哦', '请问', '能不能', '可以', '怎么', '如何', '什么', '多少', '几个'}

        # 分词(简单按字符或词语切分)
        words = re.findall(r'[\w\u4e00-\u9fff]+', question.lower())

        # 过滤停用词和过短的词
        keywords = [w for w in words if w not in stopwords and len(w) > 1]

        return keywords

    def _calculate_text_similarity(self, q1: str, q2: str) -> float:
        """
        计算两个问题的文本相似度

        使用Jaccard相似度 + 关键词匹配的综合策略
        """
        # 标准化
        n1 = self._normalize_question(q1)
        n2 = self._normalize_question(q2)

        # 如果完全相同,直接返回1.0
        if n1 == n2:
            return 1.0

        # 计算关键词重叠度
        kw1 = set(self._extract_keywords(q1))
        kw2 = set(self._extract_keywords(q2))

        if not kw1 or not kw2:
            return 0.0

        # Jaccard相似度
        intersection = len(kw1 & kw2)
        union = len(kw1 | kw2)
        jaccard_sim = intersection / union if union > 0 else 0

        # 长度相似度(惩罚长度差异过大的问题)
        len_sim = 1 - abs(len(n1) - len(n2)) / max(len(n1), len(n2))
        len_sim = max(0, len_sim)  # 确保非负

        # 综合相似度(关键词权重0.7,长度权重0.3)
        similarity = 0.7 * jaccard_sim + 0.3 * len_sim

        return similarity

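To make the 0.7/0.3 weighting concrete, a small illustrative calculation (the numbers are invented for this note and are not part of the committed file):

# Illustration only: two questions sharing 3 of 5 distinct keywords, with lengths 10 vs 12 characters.
jaccard_sim = 3 / 5                       # keyword overlap = 0.6
len_sim = 1 - abs(10 - 12) / 12           # length similarity ≈ 0.83
similarity = 0.7 * jaccard_sim + 0.3 * len_sim   # ≈ 0.67, below SIMILARITY_THRESHOLD (0.75), so no semantic hit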
    async def lookup(
        self,
        chat_id: str,
        question: str,
        redis_client=None
    ) -> Optional[Tuple[str, float]]:
        """
        在缓存中查找相似问题

        Args:
            chat_id: 聊天会话ID
            question: 用户问题
            redis_client: Redis客户端(可选)

        Returns:
            如果命中缓存,返回 (答案, 相似度)
            如果未命中,返回 None
        """
        redis = redis_client or self._get_redis()
        if redis is None:
            return None

        try:
            # 1. 精确匹配
            question_hash = self._hash_question(question)
            exact_key = self._build_cache_key(chat_id, question_hash)

            cached_data = await redis.get(exact_key)
            if cached_data:
                entry = CacheEntry.from_dict(json.loads(cached_data))
                entry.hit_count += 1
                # 更新命中次数(异步,不阻塞返回)
                await redis.set(exact_key, json.dumps(entry.to_dict()), ex=self.CACHE_TTL_HOURS * 3600)
                logger.info(f"[SemanticCache] 精确命中: {question[:30]}...")
                return entry.answer, 1.0

            # 2. 语义模糊匹配(扫描同chat_id的所有缓存)
            # 获取该chat_id的所有缓存键
            pattern = f"{self.CACHE_PREFIX}:{chat_id}:*"
            cache_keys = await redis.keys(pattern)

            if not cache_keys:
                return None

            # 遍历缓存,计算相似度
            best_match = None
            best_key = None
            best_similarity = 0.0

            for key in cache_keys:
                # 跳过刚检查过的精确匹配键
                if key == exact_key:
                    continue

                cached_data = await redis.get(key)
                if not cached_data:
                    continue

                try:
                    entry = CacheEntry.from_dict(json.loads(cached_data))
                    similarity = self._calculate_text_similarity(question, entry.question)

                    if similarity > best_similarity:
                        best_similarity = similarity
                        best_match = entry
                        best_key = key

                except (json.JSONDecodeError, KeyError) as e:
                    logger.warning(f"[SemanticCache] 解析缓存失败: {key}, {e}")
                    continue

            # 判断是否达到相似度阈值
            if best_match and best_similarity >= self.SIMILARITY_THRESHOLD:
                logger.info(f"[SemanticCache] 语义命中 (相似度={best_similarity:.2f}): {question[:30]}...")
                # 更新命中次数(写回命中条目自身的键,而不是循环结束时的最后一个 key)
                best_match.hit_count += 1
                await redis.set(best_key, json.dumps(best_match.to_dict()), ex=self.CACHE_TTL_HOURS * 3600)
                return best_match.answer, best_similarity

            return None

        except Exception as e:
            logger.error(f"[SemanticCache] 查找失败: {e}")
            return None

    async def store(
        self,
        chat_id: str,
        question: str,
        answer: str,
        redis_client=None
    ) -> bool:
        """
        将问答对存入缓存

        Args:
            chat_id: 聊天会话ID
            question: 用户问题
            answer: RAG返回的答案
            redis_client: Redis客户端(可选)

        Returns:
            是否存储成功
        """
        redis = redis_client or self._get_redis()
        if redis is None:
            return False

        try:
            # 过滤无效答案
            if not answer or len(answer.strip()) < 10:
                logger.info(f"[SemanticCache] 答案太短,不缓存: {(answer or '')[:20]}...")
                return False

            # 构建缓存条目
            question_hash = self._hash_question(question)
            cache_key = self._build_cache_key(chat_id, question_hash)

            entry = CacheEntry(
                question=question,
                answer=answer,
                created_at=time.time(),
                hit_count=0,
                chat_id=chat_id
            )

            # 检查是否已存在(避免重复存储)
            existing = await redis.get(cache_key)
            if existing:
                logger.debug(f"[SemanticCache] 问题已存在,跳过: {question[:30]}...")
                return True

            # 清理旧缓存(如果该chat_id缓存过多)
            await self._cleanup_old_cache(chat_id, redis)

            # 存储缓存
            await redis.set(
                cache_key,
                json.dumps(entry.to_dict()),
                ex=self.CACHE_TTL_HOURS * 3600
            )

            logger.info(f"[SemanticCache] 已缓存: {question[:30]}...")
            return True

        except Exception as e:
            logger.error(f"[SemanticCache] 存储失败: {e}")
            return False

    async def _cleanup_old_cache(self, chat_id: str, redis_client=None) -> int:
        """
        清理旧缓存(当缓存过多时)

        策略:删除创建时间最早的缓存

        Returns:
            删除的缓存数量
        """
        redis = redis_client or self._get_redis()
        if redis is None:
            return 0

        try:
            pattern = f"{self.CACHE_PREFIX}:{chat_id}:*"
            cache_keys = await redis.keys(pattern)

            if len(cache_keys) <= self.MAX_CACHE_SIZE:
                return 0

            # 获取所有缓存的创建时间
            cache_info = []
            for key in cache_keys:
                data = await redis.get(key)
                if data:
                    try:
                        entry = CacheEntry.from_dict(json.loads(data))
                        cache_info.append((key, entry.created_at))
                    except (json.JSONDecodeError, TypeError, KeyError):
                        pass

            # 按创建时间排序,删除最旧的
            cache_info.sort(key=lambda x: x[1])
            delete_count = len(cache_keys) - self.MAX_CACHE_SIZE

            deleted = 0
            for key, _ in cache_info[:delete_count]:
                await redis.delete(key)
                deleted += 1

            if deleted > 0:
                logger.info(f"[SemanticCache] 清理了 {deleted} 条旧缓存")

            return deleted

        except Exception as e:
            logger.error(f"[SemanticCache] 清理失败: {e}")
            return 0

    async def clear(self, chat_id: str = None, redis_client=None) -> bool:
        """
        清除缓存

        Args:
            chat_id: 要清除的chat_id,如果为None则清除所有
            redis_client: Redis客户端(可选)

        Returns:
            是否清除成功
        """
        redis = redis_client or self._get_redis()
        if redis is None:
            return False

        try:
            if chat_id:
                pattern = f"{self.CACHE_PREFIX}:{chat_id}:*"
            else:
                pattern = f"{self.CACHE_PREFIX}:*"

            cache_keys = await redis.keys(pattern)

            if cache_keys:
                await redis.delete(*cache_keys)
                logger.info(f"[SemanticCache] 清除 {len(cache_keys)} 条缓存")

            return True

        except Exception as e:
            logger.error(f"[SemanticCache] 清除失败: {e}")
            return False

    async def get_stats(self, chat_id: str = None, redis_client=None) -> Dict[str, Any]:
        """
        获取缓存统计信息

        Args:
            chat_id: 要统计的chat_id,如果为None则统计所有
            redis_client: Redis客户端(可选)

        Returns:
            统计信息字典
        """
        redis = redis_client or self._get_redis()
        if redis is None:
            return {}

        try:
            if chat_id:
                pattern = f"{self.CACHE_PREFIX}:{chat_id}:*"
            else:
                pattern = f"{self.CACHE_PREFIX}:*"

            cache_keys = await redis.keys(pattern)

            stats = {
                "total_entries": len(cache_keys),
                "chat_id": chat_id or "all",
                "config": {
                    "max_cache_size": self.MAX_CACHE_SIZE,
                    "similarity_threshold": self.SIMILARITY_THRESHOLD,
                    "cache_ttl_hours": self.CACHE_TTL_HOURS
                }
            }

            # 统计命中次数分布
            if cache_keys:
                total_hits = 0
                for key in cache_keys[:100]:  # 只采样前100个
                    data = await redis.get(key)
                    if data:
                        try:
                            entry = CacheEntry.from_dict(json.loads(data))
                            total_hits += entry.hit_count
                        except (json.JSONDecodeError, TypeError, KeyError):
                            pass

                stats["total_hits"] = total_hits
                stats["avg_hits_per_entry"] = total_hits / min(len(cache_keys), 100)

            return stats

        except Exception as e:
            logger.error(f"[SemanticCache] 统计失败: {e}")
            return {}

# 全局服务实例
_semantic_cache_service: Optional[SemanticCacheService] = None


def get_semantic_cache_service() -> SemanticCacheService:
    """获取语义缓存服务单例"""
    global _semantic_cache_service
    if _semantic_cache_service is None:
        _semantic_cache_service = SemanticCacheService()
    return _semantic_cache_service


# 便捷函数
async def lookup_question(chat_id: str, question: str, redis_client=None) -> Optional[Tuple[str, float]]:
    """查找缓存问题"""
    service = get_semantic_cache_service()
    return await service.lookup(chat_id, question, redis_client)


async def store_qa_pair(chat_id: str, question: str, answer: str, redis_client=None) -> bool:
    """存储问答对"""
    service = get_semantic_cache_service()
    return await service.store(chat_id, question, answer, redis_client)


async def clear_cache(chat_id: str = None, redis_client=None) -> bool:
    """清除缓存"""
    service = get_semantic_cache_service()
    return await service.clear(chat_id, redis_client)
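A minimal sketch of how these convenience functions are intended to be combined, mirroring the controller change above (the `redis` handle and `call_rag` are placeholders for the caller's own connection and RAG client, not part of this module):

redis = app.state.redis                     # 假设:应用启动时创建的 Redis 连接
hit = await lookup_question(chat_id, question, redis)
if hit:
    answer, similarity = hit                # 命中缓存,直接复用答案
else:
    answer = await call_rag(question)       # 假设的 RAG 调用,未命中时才执行
    await store_qa_pair(chat_id, question, answer, redis)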