增加问答缓存功能;增加对 maxkb 的测试

This commit is contained in:
Tian jianyong 2025-12-24 13:05:20 +08:00
parent 10ca1e2c00
commit 5d15567460
6 changed files with 1222 additions and 7 deletions

View File

@ -1,11 +1,12 @@
"""
简化的RAGFlow控制器 - 移除异步复杂性
简化的RAGFlow控制器 - 集成语义缓存
"""
import asyncio
import hashlib
import json
import time
from typing import Any, Dict, Generator
from typing import Any, Dict, Generator, Optional, Tuple
from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
@ -21,6 +22,18 @@ from module_admin.entity.vo.ragflow_vo import (
)
from utils.log_util import logger
from utils.response_util import ResponseUtil
from utils.semantic_cache_service import get_semantic_cache_service, lookup_question, store_qa_pair
async def _async_store_qa(chat_id: str, question: str, answer: str, redis) -> None:
    """Store a Q/A pair into the semantic cache as a background task.

    Spawned via asyncio.create_task after a successful non-stream response,
    so the HTTP reply is never blocked on cache writes.  Any failure is
    logged and swallowed: a fire-and-forget task must not raise.
    """
    try:
        await store_qa_pair(chat_id, question, answer, redis)
    except Exception as e:
        logger.warning(f"[SemanticCache] 异步存储失败: {e}")
# 使用标准的APIRouter增加认证依赖
ragflowController = APIRouter(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)])
@ -55,7 +68,7 @@ async def converse_with_chat_assistant(
converse_params: ConverseWithChatAssistantModel, # 使用正确的Pydantic模型
):
"""
与聊天助手进行对话 - 同步版本简化复杂性
与聊天助手进行对话 - 集成语义缓存版本
"""
start_time = time.perf_counter()
@ -63,6 +76,19 @@ async def converse_with_chat_assistant(
redis = getattr(request.app.state, 'redis', None)
cache_key = None
# 语义缓存查找(非流式对话)
if not converse_params.stream and redis:
cache_result = await lookup_question(
converse_params.chat_id,
converse_params.question,
redis
)
if cache_result:
cached_answer, similarity = cache_result
logger.info(f'[SemanticCache] 命中缓存 (相似度={similarity:.2f}): chat_id={converse_params.chat_id}')
# 返回缓存的答案,包装成标准响应格式
return ResponseUtil.success(data={'answer': cached_answer, 'from_cache': True, 'similarity': similarity})
# 检查是否应该使用搜索服务
try:
intent = await SearchService.classify_intent(converse_params.question)
@ -73,7 +99,7 @@ async def converse_with_chat_assistant(
except Exception as e:
logger.warning(f"意图分类失败使用RAG服务: {e}")
# 缓存检查
# 缓存检查(原有的精确缓存)
if not converse_params.stream and redis:
cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
try:
@ -104,7 +130,23 @@ async def converse_with_chat_assistant(
# 非流式响应
response = parse_result(result)
# 设置缓存
# 设置语义缓存(存储问答对)
if not converse_params.stream and redis and isinstance(result, dict) and result.get('code') == 0:
try:
answer_data = result.get('data', {})
answer_text = answer_data.get('answer') if isinstance(answer_data, dict) else str(answer_data)
if answer_text and len(answer_text.strip()) >= 10:
# 异步存储,不阻塞响应返回
asyncio.create_task(_async_store_qa(
converse_params.chat_id,
converse_params.question,
answer_text,
redis
))
except Exception as e:
logger.warning(f"语义缓存存储失败: {e}")
# 设置原有精确缓存
if redis and cache_key and isinstance(result, dict) and result.get('code') == 0:
try:
await redis.set(cache_key, json.dumps(result.get('data', {}), ensure_ascii=False), ex=60)

View File

@ -136,7 +136,8 @@ class ConverseWithChatAssistantModel(BaseModel):
会话聊天模型
"""
model_config = ConfigDict(alias_generator=to_camel, from_attributes=True)
# 移除alias_generator使用原始的snake_case参数名
model_config = ConfigDict(from_attributes=True)
chat_id: str = Field(default = None, description='会话ID')
question: str = Field(default = None, description='问题')

View File

@ -0,0 +1,399 @@
"""
MaxKB RAG 性能测试 - 测试首个token接收时间
"""
import asyncio
import aiohttp
import time
import sys
import os
import json
# Make the project root importable when the script is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# MaxKB configuration.
# Credentials are read from the environment so secrets are not committed to
# source control; the literals below remain as backward-compatible defaults.
MAXKB_API_URL = os.environ.get("MAXKB_API_URL", "http://10.0.0.202:8080")
API_KEY = os.environ.get("MAXKB_API_KEY", "application-b758e71ecb8d4237f91795365b433c35")
CSRF_TOKEN = os.environ.get("MAXKB_CSRF_TOKEN", "bDzz45UzuDGB33YCykegwEG9QMBuLba3cisIDsL5Z4tsXTrntxZq8nhVJf7qT5Pq")

# DeepSeek configuration (same environment-override pattern).
DEEPSEEK_API_BASE = os.environ.get("DEEPSEEK_API_BASE", "https://api.deepseek.com")
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "sk-56b608b26a6949e4b09b5bf5f11c8f5b")
DEEPSEEK_MODEL = os.environ.get("DEEPSEEK_MODEL", "deepseek-chat")

# Default question used when none is supplied on the command line.
TEST_QUESTION = "你好,请介绍一下你自己"
async def test_deepseek_first_token(question: str = TEST_QUESTION):
    """Measure time-to-first-token for a direct streaming DeepSeek API call.

    Returns:
        (first_token_time, total_time) in seconds, or None when the HTTP
        request fails (callers must handle the None case).
    """
    print("=" * 70)
    print("DeepSeek API 性能测试 - 首个Token接收时间")
    print("=" * 70)
    url = f"{DEEPSEEK_API_BASE}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {DEEPSEEK_API_KEY}"
    }
    data = {
        "model": DEEPSEEK_MODEL,
        "messages": [
            {"role": "user", "content": question}
        ],
        "stream": True
    }
    print(f"\n[1/2] 发送测试问题: \"{question}\"")
    print(" 等待响应...")
    total_start_time = time.time()
    first_token_time = None
    token_count = 0
    first_token_received = False
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data, headers=headers) as response:
            if response.status != 200:
                print(f" ✗ 请求失败: {response.status}")
                error_text = await response.text()
                print(f" 错误信息: {error_text}")
                return
            print(f" ✓ 连接建立,开始接收数据...")
            print(f"\n[2/2] 接收流式数据...")
            print("-" * 70)
            async for chunk in response.content:
                current_time = time.time()
                elapsed = current_time - total_start_time
                # Decode the raw SSE chunk.
                chunk_str = chunk.decode('utf-8')
                if not chunk_str.strip():
                    continue
                # First data-bearing chunk marks the time-to-first-token.
                if not first_token_received and 'data:' in chunk_str:
                    first_token_time = elapsed
                    first_token_received = True
                    print(f"\n🎯 首个Token接收时间: {first_token_time:.4f}")
                    print("-" * 70)
                # Parse each SSE "data:" line and display its content.
                for line in chunk_str.split('\n'):
                    if line.startswith('data:'):
                        data_str = line[5:].strip()
                        if data_str == '[DONE]':
                            continue
                        try:
                            data_json = json.loads(data_str)
                            # DeepSeek response format: choices[0].delta.content
                            choices = data_json.get('choices', [])
                            if choices:
                                delta = choices[0].get('delta', {})
                                content = delta.get('content', '')
                                if content:
                                    token_count += 1
                                    # Show at most the first 40 characters.
                                    display_content = content[:40] + '...' if len(content) > 40 else content
                                    print(f"[{elapsed:.3f}s] Token #{token_count}: {display_content}")
                        except json.JSONDecodeError:
                            pass
    total_time = time.time() - total_start_time
    # Summary report.
    print("\n" + "=" * 70)
    print("📊 DeepSeek API 性能测试结果")
    print("=" * 70)
    print(f" 测试问题: {question}")
    print(f" 首个Token时间: {first_token_time:.4f}" if first_token_time else " 首个Token时间: N/A")
    print(f" 总响应时间: {total_time:.4f}")
    print(f" 接收Token数: {token_count}")
    if first_token_time and total_time > 0:
        print(f" 平均Token速率: {token_count / (total_time - first_token_time):.2f} tokens/秒")
    print("=" * 70)
    return first_token_time, total_time
async def get_chat_id():
    """Open a fresh MaxKB conversation and return its chat_id string."""
    endpoint = f"{MAXKB_API_URL}/chat/api/open"
    request_headers = {
        "accept": "*/*",
        "Authorization": f"Bearer {API_KEY}"
    }
    async with aiohttp.ClientSession() as session:
        async with session.get(endpoint, headers=request_headers) as response:
            # Fail fast on any non-200 status.
            if response.status != 200:
                raise Exception(f"获取 chat_id 失败: {response.status}")
            payload = await response.json()
            # The server returns the chat_id directly under the "data" key.
            return payload.get("data")
async def test_first_token_time(question: str = TEST_QUESTION):
    """Measure time-to-first-token for a streaming MaxKB chat request.

    Returns:
        (first_token_time, total_time) in seconds, or None when obtaining a
        chat_id or sending the request fails (callers must handle None).
    """
    print("=" * 70)
    print("MaxKB RAG 性能测试 - 首个Token接收时间")
    print("=" * 70)
    # 1. Open a new conversation to obtain a chat_id.
    print("\n[1/3] 获取 chat_id...")
    chat_id_start_time = time.time()
    try:
        chat_id = await get_chat_id()
        chat_id_time = time.time() - chat_id_start_time
        print(f" ✓ chat_id: {chat_id} ({chat_id_time:.4f}s)")
    except Exception as e:
        print(f" ✗ 获取 chat_id 失败: {e}")
        return
    # 2. Send the streaming chat request.
    print(f"\n[2/3] 发送测试问题: \"{question}\"")
    print(" 等待响应...")
    url = f"{MAXKB_API_URL}/chat/api/chat_message/{chat_id}"
    headers = {
        "accept": "*/*",
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
        "X-CSRFTOKEN": CSRF_TOKEN
    }
    data = {
        "message": question,
        "stream": True,
        "re_chat": False
    }
    total_start_time = time.time()
    first_token_time = None
    token_count = 0
    first_token_received = False
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data, headers=headers) as response:
            if response.status != 200:
                print(f" ✗ 请求失败: {response.status}")
                error_text = await response.text()
                print(f" 错误信息: {error_text}")
                return
            print(f" ✓ 连接建立,开始接收数据...")
            print(f"\n[3/3] 接收流式数据...")
            print("-" * 70)
            async for chunk in response.content:
                current_time = time.time()
                elapsed = current_time - total_start_time
                # Decode the raw SSE chunk.
                chunk_str = chunk.decode('utf-8')
                if not chunk_str.strip():
                    continue
                # First "data:" chunk marks the time-to-first-token.
                if not first_token_received and chunk_str.startswith('data:'):
                    first_token_time = elapsed
                    first_token_received = True
                    print(f"\n🎯 首个Token接收时间: {first_token_time:.4f}")
                    print("-" * 70)
                # Parse each SSE "data:" line and display its content.
                for line in chunk_str.split('\n'):
                    if line.startswith('data:'):
                        data_str = line[5:].strip()
                        if data_str == '[DONE]':
                            continue
                        try:
                            import json  # NOTE(review): redundant -- json is imported at module level
                            data_json = json.loads(data_str)
                            # MaxKB response format: text lives under "content".
                            content = data_json.get('content', '')
                            if content:
                                # NOTE(review): dead branch -- first_token_received is
                                # already True once any "data:" chunk arrived above.
                                if not first_token_received:
                                    first_token_time = elapsed
                                    first_token_received = True
                                    print(f"\n🎯 首个Token接收时间: {first_token_time:.4f}")
                                    print("-" * 70)
                                token_count += 1
                                # Show at most the first 40 characters.
                                display_content = content[:40] + '...' if len(content) > 40 else content
                                print(f"[{elapsed:.3f}s] Token #{token_count}: {display_content}")
                        except json.JSONDecodeError:
                            pass
    total_time = time.time() - total_start_time
    # 3. Summary report.
    print("\n" + "=" * 70)
    print("📊 性能测试结果")
    print("=" * 70)
    print(f" 测试问题: {question}")
    print(f" 首个Token时间: {first_token_time:.4f}" if first_token_time else " 首个Token时间: N/A")
    print(f" 总响应时间: {total_time:.4f}")
    print(f" 接收Token数: {token_count}")
    if first_token_time and total_time > 0:
        print(f" 平均Token速率: {token_count / (total_time - first_token_time):.2f} tokens/秒")
    print("=" * 70)
    return first_token_time, total_time
async def run_multiple_tests(n: int = 3, question: str = None):
    """Run the MaxKB streaming test *n* times and report aggregate timings."""
    if question is None:
        question = TEST_QUESTION
    print(f"\n🔄 准备运行 {n} 次性能测试...")
    print(f" 测试问题: {question}")
    print()
    first_token_times = []
    total_times = []
    for i in range(n):
        print(f"\n{'='*70}")
        print(f"测试 #{i + 1}/{n}")
        print('='*70)
        # A fresh chat_id per run keeps each test an independent session.
        chat_id_start_time = time.time()
        chat_id = await get_chat_id()
        chat_id_time = time.time() - chat_id_start_time
        print(f" 获取chat_id: {chat_id_time:.4f}s")
        url = f"{MAXKB_API_URL}/chat/api/chat_message/{chat_id}"
        headers = {
            "accept": "*/*",
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
            "X-CSRFTOKEN": CSRF_TOKEN
        }
        data = {
            "message": question,
            "stream": True,
            "re_chat": False
        }
        total_start_time = time.time()
        first_token_time = None
        token_count = 0
        first_token_received = False
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=data, headers=headers) as response:
                if response.status != 200:
                    print(f"请求失败: {response.status}")
                    continue
                async for chunk in response.content:
                    current_time = time.time()
                    elapsed = current_time - total_start_time
                    chunk_str = chunk.decode('utf-8')
                    if not chunk_str.strip():
                        continue
                    # First "data:" chunk marks the first token.
                    if not first_token_received and chunk_str.startswith('data:'):
                        first_token_time = elapsed
                        first_token_received = True
                    if chunk_str.startswith('data:'):
                        data_str = chunk_str[5:].strip()
                        if data_str == '[DONE]':
                            continue
                        try:
                            import json  # NOTE(review): redundant -- json is imported at module level
                            data_json = json.loads(data_str)
                            # Count only chunks carrying actual content.
                            if data_json.get('content'):
                                if not first_token_received:
                                    first_token_time = elapsed
                                    first_token_received = True
                                token_count += 1
                        except:  # NOTE(review): bare except silently swallows all errors
                            pass
        total_time = time.time() - total_start_time
        if first_token_time:
            first_token_times.append(first_token_time)
            total_times.append(total_time)
            print(f"\n 结果 #{i + 1}: 首Token={first_token_time:.4f}s, 总时间={total_time:.4f}s, Tokens={token_count}")
        await asyncio.sleep(1)  # brief pause before the next run
    # Aggregate statistics across the successful runs.
    print(f"\n{'='*70}")
    print("📈 多次测试统计结果")
    print('='*70)
    if first_token_times:
        avg_first_token = sum(first_token_times) / len(first_token_times)
        min_first_token = min(first_token_times)
        max_first_token = max(first_token_times)
        print(f" 首个Token时间:")
        print(f" - 平均值: {avg_first_token:.4f}")
        print(f" - 最小值: {min_first_token:.4f}")
        print(f" - 最大值: {max_first_token:.4f}")
    if total_times:
        avg_total = sum(total_times) / len(total_times)
        print(f" 总响应时间:")
        print(f" - 平均值: {avg_total:.4f}")
    print('='*70)
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='MaxKB RAG 性能测试')
    parser.add_argument('question', nargs='?', default=TEST_QUESTION, help='测试问题(直接写在命令后面)')
    parser.add_argument('--times', type=int, default=1, help='测试次数 (默认1次)')
    parser.add_argument('--deepseek', action='store_true', help='同时测试 DeepSeek API 性能')
    args = parser.parse_args()
    if args.deepseek:
        # Comparison mode: DeepSeek first, then MaxKB, then a side-by-side report.
        print("\n" + "=" * 70)
        print("🆚 性能对比测试")
        print("=" * 70)
        print("\n>>> 开始测试 DeepSeek API...")
        # BUGFIX: both helpers return None when the HTTP request fails; fall
        # back to (None, None) so tuple unpacking cannot raise TypeError.
        ds_first, ds_total = asyncio.run(test_deepseek_first_token(args.question)) or (None, None)
        print("\n\n>>> 开始测试 MaxKB...")
        maxkb_first, maxkb_total = asyncio.run(test_first_token_time(args.question)) or (None, None)
        # Side-by-side comparison table.
        print("\n" + "=" * 70)
        print("📊 性能对比结果")
        print("=" * 70)
        print(f"\n {'服务':<15} {'首个Token':<15} {'总响应时间':<15}")
        print(f" {'-'*45}")
        print(f" {'DeepSeek':<15} {f'{ds_first:.4f}s' if ds_first else 'N/A':<15} {f'{ds_total:.4f}s' if ds_total else 'N/A':<15}")
        print(f" {'MaxKB':<15} {f'{maxkb_first:.4f}s' if maxkb_first else 'N/A':<15} {f'{maxkb_total:.4f}s' if maxkb_total else 'N/A':<15}")
        print(f" {'-'*45}")
        if ds_first and maxkb_first:
            diff = maxkb_first - ds_first
            print(f"\n ⚡ MaxKB 比 DeepSeek 慢: {diff:.4f} 秒 ({diff/ds_first*100:.1f}%)")
        print("=" * 70)
    elif args.times == 1:
        # Single MaxKB run.
        asyncio.run(test_first_token_time(args.question))
    else:
        # Repeated MaxKB runs with aggregate statistics.
        asyncio.run(run_multiple_tests(args.times, args.question))

View File

@ -0,0 +1,301 @@
"""
测试语义缓存功能
功能
1. 测试缓存的存储和查找
2. 测试语义匹配效果
3. 验证缓存性能
使用方式
python test_semantic_cache.py
作者AI Assistant
"""
import asyncio
import sys
import os
import time
# 添加项目根目录到Python路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.semantic_cache_service import (
SemanticCacheService,
get_semantic_cache_service,
lookup_question,
store_qa_pair,
clear_cache
)
from config.get_redis import RedisUtil
# Module-wide Redis client, created lazily by get_redis().
_redis_client = None


async def get_redis():
    """Lazily create and return the shared Redis client for the tests."""
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    _redis_client = await RedisUtil.create_redis_pool()
    return _redis_client
async def test_basic_functionality():
    """Test 1: basic store/lookup round-trip using exact matching."""
    print("\n" + "=" * 60)
    print("测试1: 基本功能验证")
    print("=" * 60)
    # Build a cache service around the shared Redis client.
    redis = await get_redis()
    cache = SemanticCacheService(redis)
    # Start from a clean slate for this chat_id.
    await cache.clear(chat_id="test_chat_001")
    # Exact match: first lookup must miss, lookup after store must hit.
    print("\n1. 测试精确匹配...")
    result = await cache.lookup("test_chat_001", "康达公司成立于哪一年?", redis)
    print(f" 首次查找(未命中): {result}")
    assert result is None, "首次查找应该未命中"
    # Store the Q/A pair.
    success = await cache.store(
        "test_chat_001",
        "康达公司成立于哪一年?",
        "康达公司成立于2012年是一家专注于智能机器人研发的高科技企业。",
        redis
    )
    print(f" 存储问答对: {'成功' if success else '失败'}")
    # Look up the identical question again -- expect a hit with similarity 1.0.
    result = await cache.lookup("test_chat_001", "康达公司成立于哪一年?", redis)
    print(f" 精确查找(已缓存): {'命中' if result else '未命中'}")
    if result:
        print(f" 答案: {result[0][:50]}...")
        print(f" 相似度: {result[1]:.2f}")
        assert result[1] == 1.0, "精确匹配相似度应该为1.0"
    print(" ✓ 精确匹配测试通过")
async def test_semantic_matching():
    """Test 2: semantic (fuzzy) matching across differently phrased questions."""
    print("\n" + "=" * 60)
    print("测试2: 语义匹配验证")
    print("=" * 60)
    redis = await get_redis()
    cache = SemanticCacheService(redis)
    # Start from a clean slate for this chat_id.
    await cache.clear(chat_id="test_chat_002")
    # Seed the cache with one canonical Q/A pair.
    original_question = "公司的退货政策是什么?"
    original_answer = "我们支持7天无理由退换货您可以在收到商品后7天内申请退换。"
    await cache.store("test_chat_002", original_question, original_answer, redis)
    # Query variants: (question, description of how it differs from the seed).
    test_cases = [
        ("公司的退货政策是什么?", "原问题"),
        ("退货政策是怎么样的?", "同义词替换"),
        ("如果我想退货可以吗?", "语义相关"),
        ("你们怎么退货?", "口语化表达"),
        ("产品质量有问题怎么换?", "相关但不同"),
        ("今天天气怎么样?", "完全无关"),
    ]
    print("\n语义匹配测试结果:")
    print("-" * 60)
    # This test only reports hits/misses; it makes no hard assertions.
    for question, desc in test_cases:
        result = await cache.lookup("test_chat_002", question, redis)
        if result:
            answer, similarity = result
            print(f" [{desc}]")
            print(f" Q: {question[:30]}...")
            print(f" 命中! 相似度={similarity:.2f}")
        else:
            print(f" [{desc}]")
            print(f" Q: {question[:30]}...")
            print(f" 未命中")
async def test_performance():
    """Test 3: rough latency comparison of exact vs semantic lookups."""
    print("\n" + "=" * 60)
    print("测试3: 缓存性能验证")
    print("=" * 60)
    redis = await get_redis()
    cache = SemanticCacheService(redis)
    # Start from a clean slate for this chat_id.
    await cache.clear(chat_id="test_chat_003")
    # Warm the cache with a handful of Q/A pairs.
    questions = [
        "你们公司提供哪些服务?",
        "如何联系客服?",
        "产品售后政策是什么?",
        "发货需要多长时间?",
        "支持哪些支付方式?",
    ]
    print("\n预热缓存...")
    start_time = time.time()
    for i, q in enumerate(questions):
        await cache.store(
            "test_chat_003",
            q,
            f"这是针对问题的回答 #{i+1}",
            redis
        )
    warm_time = time.time() - start_time
    print(f" 存储5条问答耗时: {warm_time*1000:.2f}ms")
    print("\n性能测试:")
    # Exact-match lookups: a single hash-key GET per call.
    start_time = time.time()
    for _ in range(100):
        await cache.lookup("test_chat_003", "你们公司提供哪些服务?", redis)
    exact_time = time.time() - start_time
    print(f" 100次精确匹配: {exact_time*1000:.2f}ms (平均{exact_time*10:.2f}ms/次)")
    # Semantic lookups: each call scans every cached entry for this chat_id.
    start_time = time.time()
    for _ in range(10):
        await cache.lookup("test_chat_003", "公司提供什么服务?", redis)
    semantic_time = time.time() - start_time
    print(f" 10次语义匹配: {semantic_time*1000:.2f}ms (平均{semantic_time*100:.2f}ms/次)")
async def test_cache_statistics():
    """Test 4: cache statistics reporting via get_stats()."""
    print("\n" + "=" * 60)
    print("测试4: 缓存统计功能")
    print("=" * 60)
    redis = await get_redis()
    cache = SemanticCacheService(redis)
    # Start from a clean slate for this chat_id.
    await cache.clear(chat_id="test_chat_stats")
    # Seed two entries.
    # NOTE(review): SemanticCacheService.store() rejects answers shorter than
    # 10 characters, so these seeds may never persist -- verify intent.
    await cache.store("test_chat_stats", "问题1", "答案1", redis)
    await cache.store("test_chat_stats", "问题2", "答案2", redis)
    # Generate a few lookups to accumulate hit counts.
    await cache.lookup("test_chat_stats", "问题1", redis)
    await cache.lookup("test_chat_stats", "问题1", redis)
    await cache.lookup("test_chat_stats", "问题2", redis)
    # Fetch and display the statistics.
    stats = await cache.get_stats("test_chat_stats", redis)
    print(f"\n缓存统计信息:")
    print(f" 总条目数: {stats.get('total_entries', 0)}")
    print(f" 配置信息:")
    print(f" - 最大缓存大小: {stats['config']['max_cache_size']}")
    print(f" - 相似度阈值: {stats['config']['similarity_threshold']}")
    print(f" - 缓存过期时间: {stats['config']['cache_ttl_hours']}小时")
async def test_integration_with_ragflow():
    """Test 5: end-to-end cache-then-RAG flow with a mocked RAG backend."""
    print("\n" + "=" * 60)
    print("测试5: 集成测试 - 模拟RAG流程")
    print("=" * 60)
    redis = await get_redis()
    cache = SemanticCacheService(redis)

    # Stand-in for the real RAG service; sleeps 0.1 s to simulate latency.
    class MockRAGService:
        async def query(self, question):
            await asyncio.sleep(0.1)
            return f"RAG生成的答案{question}的详细回答..."
    rag_service = MockRAGService()
    # Start from a clean slate for this chat_id.
    await cache.clear(chat_id="test_integration")
    test_questions = [
        "康达机器人有哪些产品?",
        "康达机器人有哪些产品?",  # repeated question -- expect an exact cache hit
        "你们的机器人产品有哪些?",  # paraphrase -- expect a semantic match
        "产品价格是多少?",  # new question -- expect a miss
    ]
    print("\n模拟用户连续提问:")
    print("-" * 60)
    for i, question in enumerate(test_questions):
        print(f"\n用户提问 #{i+1}: {question[:30]}...")
        # 1. Check the cache first.
        cached = await cache.lookup("test_integration", question, redis)
        if cached:
            answer, similarity = cached
            print(f" ✓ 缓存命中! (相似度={similarity:.2f})")
            print(f" → 直接返回缓存答案")
        else:
            print(f" ✗ 缓存未命中")
            print(f" → 调用RAG服务...")
            # 2. Fall back to the (mock) RAG service.
            answer = await rag_service.query(question)
            print(f" → RAG返回答案")
            # 3. Store the fresh pair so later questions can hit.
            await cache.store("test_integration", question, answer, redis)
            print(f" → 已存入缓存")
        print(f" 答案预览: {answer[:40]}...")
async def main():
    """Run every semantic-cache test, then clean up all test keys."""
    # NOTE(review): the banner characters in these literals appear to have
    # been lost in transfer (empty strings); restore the intended box-drawing
    # characters if this is the canonical copy.
    print("\n" + "" + "" * 58 + "")
    print("" + " " * 15 + "语义缓存功能测试" + " " * 20 + "")
    print("" + "" * 58 + "")
    try:
        await test_basic_functionality()
        await test_semantic_matching()
        await test_performance()
        await test_cache_statistics()
        await test_integration_with_ragflow()
        print("\n" + "=" * 60)
        print("✓ 所有测试完成!")
        print("=" * 60)
    except Exception as e:
        print(f"\n✗ 测试失败: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Remove every chat_id used by the tests above.
        redis = await get_redis()
        cache = SemanticCacheService(redis)
        await cache.clear(chat_id="test_chat_001")
        await cache.clear(chat_id="test_chat_002")
        await cache.clear(chat_id="test_chat_003")
        await cache.clear(chat_id="test_chat_stats")
        await cache.clear(chat_id="test_integration")
        print("\n测试数据已清理")


if __name__ == "__main__":
    asyncio.run(main())

View File

@ -19,7 +19,7 @@ AUTH_TOKEN = os.environ.get(
"Bearer ",
)
DEFAULT_CHAT_ID = os.environ.get(
"RAGFLOW_CHAT_ID", "db4bb966895b11f08cda0242ac130006"
"RAGFLOW_CHAT_ID", "a455abcadbcb11f0884b0242ac130006"
)
DEFAULT_SESSION_ID = os.environ.get(
"RAGFLOW_SESSION_ID", "38d765e48a3811f0be310242ac130006"

View File

@ -0,0 +1,472 @@
"""
语义缓存服务 - 基于问答历史的缓存方案
功能
1. 记录用户通过RAG返回的对话历史问题+答案
2. 当用户后续提问时基于语义先搜索缓存
3. 如果命中缓存直接返回答案
4. 如果未命中调用RAG并将结果存入缓存
技术方案
- 使用Redis存储问答对问题文本 + 答案文本
- 使用简单的文本相似度算法因为没有现成的Embedding模型
- 缓存键rag:cache:{chat_id}:{question_hash}
- 支持多级匹配精确匹配 模糊匹配
作者AI Assistant
日期2024
"""
import hashlib
import json
import time
import re
import logging
from typing import Optional, Tuple, Dict, Any, List
from dataclasses import dataclass, asdict
# Module logger
logger = logging.getLogger(__name__)


@dataclass
class CacheEntry:
    """One cached question/answer record, JSON-serialized into Redis."""
    question: str          # original user question
    answer: str            # answer returned by the RAG service
    created_at: float      # unix timestamp at store time
    hit_count: int = 0     # number of cache hits served from this entry
    chat_id: str = ""      # owning conversation

    def to_dict(self) -> dict:
        """Serialize into a plain dict for JSON storage."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> 'CacheEntry':
        """Rebuild an entry from a dict produced by to_dict()."""
        return cls(**data)


class SemanticCacheService:
    """
    Semantic cache for RAG question/answer pairs.

    Matching strategies, tried in order by lookup():
    1. Exact match  -- normalized question text is identical (hash lookup).
    2. Semantic match -- Jaccard overlap of extracted keywords combined with
       a length-similarity term, against every entry of the same chat_id.
    """

    # Cache configuration
    CACHE_PREFIX = "rag:semantic:cache"
    MAX_CACHE_SIZE = 1000        # max cached entries per chat_id
    SIMILARITY_THRESHOLD = 0.75  # minimum similarity for a semantic hit
    CACHE_TTL_HOURS = 24         # entry time-to-live in hours

    def __init__(self, redis_client=None):
        """
        Args:
            redis_client: Redis client instance. When None, the service is
                effectively disabled unless a client is passed per call.
        """
        self.redis = redis_client
        self._default_redis = None

    def _get_redis(self):
        """Return the Redis client to use, or None when unavailable."""
        if self.redis is not None:
            return self.redis
        if self._default_redis is None:
            from config.get_redis import RedisUtil  # noqa: F401
            # NOTE(review): creating a pool here would require an await;
            # callers are expected to pass a redis client explicitly.
            logger.warning("未提供Redis客户端语义缓存功能不可用")
        return self._default_redis

    def _build_cache_key(self, chat_id: str, question_hash: str) -> str:
        """Build the Redis key for one cached entry."""
        return f"{self.CACHE_PREFIX}:{chat_id}:{question_hash}"

    def _hash_question(self, question: str) -> str:
        """Hash the normalized question into a short, stable identifier."""
        normalized = self._normalize_question(question)
        return hashlib.md5(normalized.encode('utf-8')).hexdigest()[:16]

    def _normalize_question(self, question: str) -> str:
        """Normalize whitespace and unify full-width punctuation."""
        q = question.strip()
        q = re.sub(r'\s+', ' ', q)
        # BUGFIX: the previous revision called replace('', '?') with an EMPTY
        # old-string (the full-width characters were lost), which inserts '?'
        # between every character. Reconstructed per the original intent of
        # mapping Chinese punctuation to ASCII.
        q = q.replace('?', '?').replace(',', ',').replace('。', '.')
        return q

    def _extract_keywords(self, question: str) -> List[str]:
        """
        Extract keywords for semantic matching.

        CJK text has no word separators, so the token regex yields one long
        run per clause; long CJK runs are decomposed into character bigrams
        so two differently phrased Chinese questions can share keywords.
        """
        # Common stopwords / filler phrases. Single-character stopwords are
        # kept for bigram filtering even though len(w) > 1 removes them as
        # whole tokens.
        stopwords = {'的', '了', '吗', '呢', '啊', '是', '在', '有', '和', '与', '或', '请',
                     '请问', '能不能', '可以', '怎么', '如何', '什么', '多少', '几个'}
        words = re.findall(r'[\w\u4e00-\u9fff]+', question.lower())
        keywords: List[str] = []
        for w in words:
            if w in stopwords or len(w) <= 1:
                continue
            if len(w) > 2 and re.search(r'[\u4e00-\u9fff]', w):
                # Split long CJK runs into overlapping bigrams.
                keywords.extend(
                    bg for bg in (w[i:i + 2] for i in range(len(w) - 1))
                    if bg not in stopwords
                )
            else:
                keywords.append(w)
        return keywords

    def _calculate_text_similarity(self, q1: str, q2: str) -> float:
        """
        Similarity in [0, 1]: 0.7 * keyword Jaccard + 0.3 * length ratio.
        Identical normalized questions short-circuit to 1.0.
        """
        n1 = self._normalize_question(q1)
        n2 = self._normalize_question(q2)
        if n1 == n2:
            return 1.0
        kw1 = set(self._extract_keywords(q1))
        kw2 = set(self._extract_keywords(q2))
        if not kw1 or not kw2:
            return 0.0
        # Jaccard similarity over keyword sets.
        intersection = len(kw1 & kw2)
        union = len(kw1 | kw2)
        jaccard_sim = intersection / union if union > 0 else 0
        # Penalize large length differences between the questions.
        len_sim = 1 - abs(len(n1) - len(n2)) / max(len(n1), len(n2))
        len_sim = max(0, len_sim)
        return 0.7 * jaccard_sim + 0.3 * len_sim

    async def lookup(
        self,
        chat_id: str,
        question: str,
        redis_client=None
    ) -> Optional[Tuple[str, float]]:
        """
        Find a cached answer for a (possibly paraphrased) question.

        Args:
            chat_id: conversation ID scoping the cache.
            question: user question.
            redis_client: optional Redis client override.

        Returns:
            (answer, similarity) on a hit, None on a miss or any error.
        """
        redis = redis_client or self._get_redis()
        if redis is None:
            return None
        try:
            # 1. Exact match on the normalized-question hash.
            question_hash = self._hash_question(question)
            exact_key = self._build_cache_key(chat_id, question_hash)
            cached_data = await redis.get(exact_key)
            if cached_data:
                entry = CacheEntry.from_dict(json.loads(cached_data))
                entry.hit_count += 1
                # Refresh the entry (and its TTL) with the new hit count.
                await redis.set(exact_key, json.dumps(entry.to_dict()), ex=self.CACHE_TTL_HOURS * 3600)
                logger.info(f"[SemanticCache] 精确命中: {question[:30]}...")
                return entry.answer, 1.0
            # 2. Semantic scan over every entry of this chat_id.
            # NOTE(review): KEYS is O(N) over the keyspace; consider SCAN or a
            # per-chat index set for large deployments.
            pattern = f"{self.CACHE_PREFIX}:{chat_id}:*"
            cache_keys = await redis.keys(pattern)
            if not cache_keys:
                return None
            best_match = None
            best_key = None  # BUGFIX: remember WHICH key the best match lives under
            best_similarity = 0.0
            for key in cache_keys:
                # Skip the exact key already checked above.
                if key == exact_key:
                    continue
                cached_data = await redis.get(key)
                if not cached_data:
                    continue
                try:
                    entry = CacheEntry.from_dict(json.loads(cached_data))
                    similarity = self._calculate_text_similarity(question, entry.question)
                    if similarity > best_similarity:
                        best_similarity = similarity
                        best_match = entry
                        best_key = key
                except (json.JSONDecodeError, KeyError, TypeError) as e:
                    logger.warning(f"[SemanticCache] 解析缓存失败: {key}, {e}")
                    continue
            if best_match and best_similarity >= self.SIMILARITY_THRESHOLD:
                logger.info(f"[SemanticCache] 语义命中 (相似度={best_similarity:.2f}): {question[:30]}...")
                best_match.hit_count += 1
                # BUGFIX: previously this wrote to the loop variable `key`
                # (the LAST scanned key), overwriting an unrelated entry with
                # the best match's data. Write to the matched entry's key.
                await redis.set(best_key, json.dumps(best_match.to_dict()), ex=self.CACHE_TTL_HOURS * 3600)
                return best_match.answer, best_similarity
            return None
        except Exception as e:
            logger.error(f"[SemanticCache] 查找失败: {e}")
            return None

    async def store(
        self,
        chat_id: str,
        question: str,
        answer: str,
        redis_client=None
    ) -> bool:
        """
        Store a question/answer pair.

        Args:
            chat_id: conversation ID scoping the cache.
            question: user question.
            answer: answer returned by the RAG service.
            redis_client: optional Redis client override.

        Returns:
            True when stored (or already present), False otherwise.
        """
        redis = redis_client or self._get_redis()
        if redis is None:
            return False
        try:
            # Reject empty or trivially short answers.
            if not answer or len(answer.strip()) < 10:
                # BUGFIX: guard against answer being None before slicing.
                logger.info(f"[SemanticCache] 答案太短,不缓存: {(answer or '')[:20]}...")
                return False
            question_hash = self._hash_question(question)
            cache_key = self._build_cache_key(chat_id, question_hash)
            entry = CacheEntry(
                question=question,
                answer=answer,
                created_at=time.time(),
                hit_count=0,
                chat_id=chat_id
            )
            # Avoid overwriting an existing entry (keeps its hit count).
            existing = await redis.get(cache_key)
            if existing:
                logger.debug(f"[SemanticCache] 问题已存在,跳过: {question[:30]}...")
                return True
            # Evict old entries when this chat_id holds too many.
            await self._cleanup_old_cache(chat_id, redis)
            await redis.set(
                cache_key,
                json.dumps(entry.to_dict()),
                ex=self.CACHE_TTL_HOURS * 3600
            )
            logger.info(f"[SemanticCache] 已缓存: {question[:30]}...")
            return True
        except Exception as e:
            logger.error(f"[SemanticCache] 存储失败: {e}")
            return False

    async def _cleanup_old_cache(self, chat_id: str, redis_client=None) -> int:
        """
        Evict the oldest entries when a chat_id exceeds MAX_CACHE_SIZE.

        Returns:
            Number of entries deleted.
        """
        redis = redis_client or self._get_redis()
        if redis is None:
            return 0
        try:
            pattern = f"{self.CACHE_PREFIX}:{chat_id}:*"
            cache_keys = await redis.keys(pattern)
            if len(cache_keys) <= self.MAX_CACHE_SIZE:
                return 0
            # Collect creation timestamps for every entry.
            cache_info = []
            for key in cache_keys:
                data = await redis.get(key)
                if data:
                    try:
                        entry = CacheEntry.from_dict(json.loads(data))
                        cache_info.append((key, entry.created_at))
                    except Exception:  # was a bare except; keep best-effort
                        pass
            # Delete the oldest entries first.
            cache_info.sort(key=lambda x: x[1])
            delete_count = len(cache_keys) - self.MAX_CACHE_SIZE
            deleted = 0
            for key, _ in cache_info[:delete_count]:
                await redis.delete(key)
                deleted += 1
            if deleted > 0:
                logger.info(f"[SemanticCache] 清理了 {deleted} 条旧缓存")
            return deleted
        except Exception as e:
            logger.error(f"[SemanticCache] 清理失败: {e}")
            return 0

    async def clear(self, chat_id: str = None, redis_client=None) -> bool:
        """
        Drop cached entries.

        Args:
            chat_id: chat to clear; when None, clears ALL semantic cache keys.
            redis_client: optional Redis client override.

        Returns:
            True on success, False on any error.
        """
        redis = redis_client or self._get_redis()
        if redis is None:
            return False
        try:
            if chat_id:
                pattern = f"{self.CACHE_PREFIX}:{chat_id}:*"
            else:
                pattern = f"{self.CACHE_PREFIX}:*"
            cache_keys = await redis.keys(pattern)
            if cache_keys:
                await redis.delete(*cache_keys)
                logger.info(f"[SemanticCache] 清除 {len(cache_keys)} 条缓存")
            return True
        except Exception as e:
            logger.error(f"[SemanticCache] 清除失败: {e}")
            return False

    async def get_stats(self, chat_id: str = None, redis_client=None) -> Dict[str, Any]:
        """
        Report cache statistics.

        Args:
            chat_id: chat to inspect; when None, all chats.
            redis_client: optional Redis client override.

        Returns:
            Dict with entry counts, sampled hit totals and the configuration;
            empty dict on error or when Redis is unavailable.
        """
        redis = redis_client or self._get_redis()
        if redis is None:
            return {}
        try:
            if chat_id:
                pattern = f"{self.CACHE_PREFIX}:{chat_id}:*"
            else:
                pattern = f"{self.CACHE_PREFIX}:*"
            cache_keys = await redis.keys(pattern)
            stats = {
                "total_entries": len(cache_keys),
                "chat_id": chat_id or "all",
                "config": {
                    "max_cache_size": self.MAX_CACHE_SIZE,
                    "similarity_threshold": self.SIMILARITY_THRESHOLD,
                    "cache_ttl_hours": self.CACHE_TTL_HOURS
                }
            }
            # Sample at most 100 entries when aggregating hit counts.
            if cache_keys:
                total_hits = 0
                for key in cache_keys[:100]:
                    data = await redis.get(key)
                    if data:
                        try:
                            entry = CacheEntry.from_dict(json.loads(data))
                            total_hits += entry.hit_count
                        except Exception:  # was a bare except; keep best-effort
                            pass
                stats["total_hits"] = total_hits
                stats["avg_hits_per_entry"] = total_hits / min(len(cache_keys), 100)
            return stats
        except Exception as e:
            logger.error(f"[SemanticCache] 统计失败: {e}")
            return {}
# Process-wide singleton instance.
_semantic_cache_service: Optional[SemanticCacheService] = None


def get_semantic_cache_service() -> SemanticCacheService:
    """Return the process-wide SemanticCacheService, creating it on demand."""
    global _semantic_cache_service
    if _semantic_cache_service is None:
        _semantic_cache_service = SemanticCacheService()
    return _semantic_cache_service


# Convenience module-level wrappers around the singleton.

async def lookup_question(chat_id: str, question: str, redis_client=None) -> Optional[Tuple[str, float]]:
    """Look up a cached answer for *question* within *chat_id*."""
    return await get_semantic_cache_service().lookup(chat_id, question, redis_client)


async def store_qa_pair(chat_id: str, question: str, answer: str, redis_client=None) -> bool:
    """Store a question/answer pair in the semantic cache."""
    return await get_semantic_cache_service().store(chat_id, question, answer, redis_client)


async def clear_cache(chat_id: str = None, redis_client=None) -> bool:
    """Drop cached entries (one chat_id, or everything when None)."""
    return await get_semantic_cache_service().clear(chat_id, redis_client)