kangda-robot-backend/ruoyi-fastapi-backend/test/test_semantic_cache.py

302 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试语义缓存功能
功能:
1. 测试缓存的存储和查找
2. 测试语义匹配效果
3. 验证缓存性能
使用方式:
python test_semantic_cache.py
作者: AI Assistant
"""
import asyncio
import sys
import os
import time
# 添加项目根目录到Python路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.semantic_cache_service import (
SemanticCacheService,
get_semantic_cache_service,
lookup_question,
store_qa_pair,
clear_cache
)
from config.get_redis import RedisUtil
# Redis client shared by every test in this script; filled in lazily by
# get_redis() on its first call.
_redis_client = None


async def get_redis():
    """Return the shared Redis client, creating it on the first call.

    The client comes from ``RedisUtil.create_redis_pool()`` and is cached in
    the module-level ``_redis_client`` so later calls reuse the same object.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    _redis_client = await RedisUtil.create_redis_pool()
    return _redis_client
async def test_basic_functionality():
    """Check the store/lookup round trip for an exact question match.

    Steps: clear the ``test_chat_001`` cache, confirm that a lookup misses,
    store one question/answer pair, then confirm that looking up the very
    same question hits with a similarity of 1.0.

    Raises:
        AssertionError: if the first lookup hits, the store reports failure,
            the second lookup misses, or the reported similarity is not 1.0.
    """
    print("\n" + "=" * 60)
    print("测试1: 基本功能验证")
    print("=" * 60)

    chat_id = "test_chat_001"
    question = "康达公司成立于哪一年?"
    answer = "康达公司成立于2012年是一家专注于智能机器人研发的高科技企业。"

    redis = await get_redis()
    cache = SemanticCacheService(redis)
    # Start from an empty cache so the first lookup below must miss.
    await cache.clear(chat_id=chat_id)

    print("\n1. 测试精确匹配...")
    result = await cache.lookup(chat_id, question, redis)
    print(f" 首次查找(未命中): {result}")
    assert result is None, "首次查找应该未命中"

    success = await cache.store(chat_id, question, answer, redis)
    print(f" 存储问答对: {'成功' if success else '失败'}")
    # The flag used to be printed only; a failed store must fail the test.
    assert success, "存储问答对应该成功"

    result = await cache.lookup(chat_id, question, redis)
    print(f" 精确查找(已缓存): {'命中' if result else '未命中'}")
    # Previously a miss here skipped the similarity check and the test still
    # "passed"; a miss right after a successful store is a failure.
    assert result, "存储后再次查找应该命中缓存"

    cached_answer, similarity = result
    print(f" 答案: {cached_answer[:50]}...")
    print(f" 相似度: {similarity:.2f}")
    assert similarity == 1.0, "精确匹配相似度应该为1.0"
    print(" ✓ 精确匹配测试通过")
async def test_semantic_matching():
    """Report which rephrasings of a stored question hit the cache.

    One question/answer pair is stored under ``test_chat_002``. Six variants
    of the question (from the identical wording to an unrelated one) are then
    looked up, and the outcome of each lookup is printed. This is an
    informational report: nothing is asserted.
    """
    print("\n" + "=" * 60)
    print("测试2: 语义匹配验证")
    print("=" * 60)

    chat_id = "test_chat_002"
    redis = await get_redis()
    cache = SemanticCacheService(redis)

    # Reset the chat, then seed it with the single reference pair.
    await cache.clear(chat_id=chat_id)
    await cache.store(
        chat_id,
        "公司的退货政策是什么?",
        "我们支持7天无理由退换货您可以在收到商品后7天内申请退换。",
        redis,
    )

    # (label, question) pairs, ordered from closest to furthest wording.
    variants = (
        ("原问题", "公司的退货政策是什么?"),
        ("同义词替换", "退货政策是怎么样的?"),
        ("语义相关", "如果我想退货可以吗?"),
        ("口语化表达", "你们怎么退货?"),
        ("相关但不同", "产品质量有问题怎么换?"),
        ("完全无关", "今天天气怎么样?"),
    )

    print("\n语义匹配测试结果:")
    print("-" * 60)
    for label, question in variants:
        hit = await cache.lookup(chat_id, question, redis)
        matched = bool(hit)
        if matched:
            # A hit is an (answer, similarity) pair; only the score is shown.
            _, similarity = hit
        print(f" [{label}]")
        print(f" Q: {question[:30]}...")
        print(f" 命中! 相似度={similarity:.2f}" if matched else " 未命中")
async def test_performance():
    """Time cache writes, exact-match lookups and rephrased lookups.

    Five question/answer pairs are stored under ``test_chat_003``. The test
    then times 100 lookups of a stored question and 10 lookups of a
    rephrased one, printing the total and the per-call average in
    milliseconds. Nothing is asserted; the numbers are informational.

    Durations use ``time.perf_counter()``, the clock intended for measuring
    short intervals; it is not affected by system clock adjustments.
    """
    print("\n" + "=" * 60)
    print("测试3: 缓存性能验证")
    print("=" * 60)

    chat_id = "test_chat_003"
    exact_runs = 100
    semantic_runs = 10

    redis = await get_redis()
    cache = SemanticCacheService(redis)
    await cache.clear(chat_id=chat_id)

    questions = [
        "你们公司提供哪些服务?",
        "如何联系客服?",
        "产品售后政策是什么?",
        "发货需要多长时间?",
        "支持哪些支付方式?",
    ]

    print("\n预热缓存...")
    started = time.perf_counter()
    for number, question in enumerate(questions, start=1):
        await cache.store(chat_id, question, f"这是针对问题的回答 #{number}", redis)
    warm_ms = (time.perf_counter() - started) * 1000
    # The count is derived from the list so the message cannot go stale.
    print(f" 存储{len(questions)}条问答耗时: {warm_ms:.2f}ms")

    print("\n性能测试:")

    # Question that was stored verbatim above.
    exact_ms = await _time_lookups(
        cache, redis, chat_id, "你们公司提供哪些服务?", exact_runs
    )
    print(
        f" {exact_runs}次精确匹配: {exact_ms:.2f}ms "
        f"(平均{exact_ms / exact_runs:.2f}ms/次)"
    )

    # Rephrased question. Per the original author's note this path scans all
    # cached entries, which is why fewer iterations are used.
    semantic_ms = await _time_lookups(
        cache, redis, chat_id, "公司提供什么服务?", semantic_runs
    )
    print(
        f" {semantic_runs}次语义匹配: {semantic_ms:.2f}ms "
        f"(平均{semantic_ms / semantic_runs:.2f}ms/次)"
    )


async def _time_lookups(cache, redis, chat_id, question, runs):
    """Return the milliseconds spent on ``runs`` sequential cache lookups."""
    started = time.perf_counter()
    for _ in range(runs):
        await cache.lookup(chat_id, question, redis)
    return (time.perf_counter() - started) * 1000
async def test_cache_statistics():
    """Print the statistics the cache reports after a small workload.

    Two question/answer pairs are stored under ``test_chat_stats`` and three
    lookups are issued. The entry count and the configuration values
    returned by ``get_stats`` are then printed. Nothing is asserted.
    """
    print("\n" + "=" * 60)
    print("测试4: 缓存统计功能")
    print("=" * 60)

    chat_id = "test_chat_stats"
    redis = await get_redis()
    cache = SemanticCacheService(redis)
    await cache.clear(chat_id=chat_id)

    # Seed two entries, then look them up a few times.
    for question, answer in (("问题1", "答案1"), ("问题2", "答案2")):
        await cache.store(chat_id, question, answer, redis)
    for question in ("问题1", "问题1", "问题2"):
        await cache.lookup(chat_id, question, redis)

    stats = await cache.get_stats(chat_id, redis)
    print("\n缓存统计信息:")
    print(f" 总条目数: {stats.get('total_entries', 0)}")
    print(" 配置信息:")
    # Read after the header lines so a missing 'config' key fails at the same
    # point in the output as before.
    config = stats['config']
    print(f" - 最大缓存大小: {config['max_cache_size']}")
    print(f" - 相似度阈值: {config['similarity_threshold']}")
    print(f" - 缓存过期时间: {config['cache_ttl_hours']}小时")
async def test_integration_with_ragflow():
    """Simulate a cache-in-front-of-RAG flow over a short conversation.

    For each question the cache is consulted first. On a hit the cached
    answer is used. On a miss a stand-in RAG call (0.1 s delay) produces the
    answer, which is then stored in the cache. Every step is printed;
    nothing is asserted.
    """
    print("\n" + "=" * 60)
    print("测试5: 集成测试 - 模拟RAG流程")
    print("=" * 60)

    chat_id = "test_integration"
    redis = await get_redis()
    cache = SemanticCacheService(redis)

    async def ask_rag(question):
        """Stand-in for the RAG backend: wait briefly, return a canned answer."""
        await asyncio.sleep(0.1)
        return f"RAG生成的答案{question}的详细回答..."

    await cache.clear(chat_id=chat_id)

    conversation = [
        "康达机器人有哪些产品?",
        "康达机器人有哪些产品?",  # repeated question, expected to hit the cache
        "你们的机器人产品有哪些?",  # similar wording, expected to match semantically
        "产品价格是多少?",  # new question, expected to miss
    ]

    print("\n模拟用户连续提问:")
    print("-" * 60)
    for number, question in enumerate(conversation, start=1):
        print(f"\n用户提问 #{number}: {question[:30]}...")

        cached = await cache.lookup(chat_id, question, redis)
        if not cached:
            print(" ✗ 缓存未命中")
            print(" → 调用RAG服务...")
            answer = await ask_rag(question)
            print(" → RAG返回答案")
            # Remember the fresh answer so later questions can reuse it.
            await cache.store(chat_id, question, answer, redis)
            print(" → 已存入缓存")
        else:
            answer, similarity = cached
            print(f" ✓ 缓存命中! (相似度={similarity:.2f})")
            print(" → 直接返回缓存答案")

        print(f" 答案预览: {answer[:40]}...")
async def main():
    """Run every semantic-cache check in order, then remove the test data.

    A failing check is reported with its message and traceback, and the
    remaining checks are skipped. Cleanup of the test chat ids always runs
    and is best-effort: a cleanup error is printed instead of escaping as a
    second traceback.

    Returns:
        bool: ``True`` when all checks and the cleanup finished without
        raising, ``False`` otherwise. Callers that ignore the return value
        keep working as before.
    """
    # NOTE(review): the border literals below are empty, so no frame is drawn
    # around the title. The border glyphs look like they were lost from the
    # string literals; confirm against the original file.
    print("\n" + "" + "" * 58 + "")
    print("" + " " * 15 + "语义缓存功能测试" + " " * 20 + "")
    print("" + "" * 58 + "")

    checks = (
        test_basic_functionality,
        test_semantic_matching,
        test_performance,
        test_cache_statistics,
        test_integration_with_ragflow,
    )
    test_chat_ids = (
        "test_chat_001",
        "test_chat_002",
        "test_chat_003",
        "test_chat_stats",
        "test_integration",
    )

    passed = True
    try:
        for check in checks:
            await check()
        print("\n" + "=" * 60)
        print("✓ 所有测试完成!")
        print("=" * 60)
    except Exception as e:
        # Top-level boundary: report the failure and remember it, so the
        # process can exit non-zero instead of looking successful.
        passed = False
        print(f"\n✗ 测试失败: {e}")
        import traceback
        traceback.print_exc()
    finally:
        try:
            redis = await get_redis()
            cache = SemanticCacheService(redis)
            for chat_id in test_chat_ids:
                await cache.clear(chat_id=chat_id)
            print("\n测试数据已清理")
        except Exception as cleanup_error:
            # If Redis itself was the reason the checks failed, cleanup will
            # fail too; report it without hiding the original failure.
            passed = False
            print(f"\n✗ 清理测试数据失败: {cleanup_error}")
    return passed


if __name__ == "__main__":
    # Turn the result into the process exit status so that shells and CI can
    # detect a failed run.
    sys.exit(0 if asyncio.run(main()) else 1)