302 lines
9.3 KiB
Python
302 lines
9.3 KiB
Python
"""
Tests for the semantic cache feature.

What is covered:
1. Storing entries in the cache and looking them up
2. Semantic matching quality
3. Cache performance

Usage:
    python test_semantic_cache.py

Author: AI Assistant
"""
|
||
|
||
import asyncio
|
||
import sys
|
||
import os
|
||
import time
|
||
|
||
# 添加项目根目录到Python路径
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from utils.semantic_cache_service import (
|
||
SemanticCacheService,
|
||
get_semantic_cache_service,
|
||
lookup_question,
|
||
store_qa_pair,
|
||
clear_cache
|
||
)
|
||
from config.get_redis import RedisUtil
|
||
|
||
|
||
# Module-wide Redis client, created on first use by get_redis().
_redis_client = None


async def get_redis():
    """Return the shared Redis client, creating it on the first call.

    The client is obtained from ``RedisUtil.create_redis_pool()`` once and
    kept in the module-level ``_redis_client`` so later calls reuse it.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    _redis_client = await RedisUtil.create_redis_pool()
    return _redis_client
|
||
|
||
|
||
async def test_basic_functionality():
    """Verify the store / exact-lookup round trip of the semantic cache.

    Steps, all under chat id ``test_chat_001``:
      1. clear the chat's cache and check that a lookup misses (``None``);
      2. store one question/answer pair;
      3. look the identical question up again and require a hit whose
         similarity is 1.0.

    ``cache.lookup`` is used here as returning ``None`` on a miss and an
    ``(answer, similarity)`` pair on a hit.

    Raises:
        AssertionError: if any of the expectations above does not hold.
    """
    import math  # local import: only needed for the tolerance comparison below

    print("\n" + "=" * 60)
    print("测试1: 基本功能验证")
    print("=" * 60)

    redis = await get_redis()
    cache = SemanticCacheService(redis)

    chat_id = "test_chat_001"
    question = "康达公司成立于哪一年?"

    # Start from a clean slate so the first lookup is guaranteed to miss.
    await cache.clear(chat_id=chat_id)

    print("\n1. 测试精确匹配...")
    result = await cache.lookup(chat_id, question, redis)
    print(f" 首次查找(未命中): {result}")
    assert result is None, "首次查找应该未命中"

    success = await cache.store(
        chat_id,
        question,
        "康达公司成立于2012年,是一家专注于智能机器人研发的高科技企业。",
        redis,
    )
    print(f" 存储问答对: {'成功' if success else '失败'}")
    # A failed store makes the rest of the test meaningless; fail here.
    assert success, "存储问答对应该成功"

    result = await cache.lookup(chat_id, question, redis)
    print(f" 精确查找(已缓存): {'命中' if result else '未命中'}")
    # Previously a miss here skipped the assertion and the test still
    # reported success; a miss after storing is now a hard failure.
    assert result, "存储后精确查找应该命中"

    answer, similarity = result[0], result[1]
    print(f" 答案: {answer[:50]}...")
    print(f" 相似度: {similarity:.2f}")
    # Compare with a tolerance: the score is a float and may not be
    # bit-exact 1.0 depending on how the service computes it.
    assert math.isclose(similarity, 1.0, abs_tol=1e-6), "精确匹配相似度应该为1.0"

    print(" ✓ 精确匹配测试通过")
|
||
|
||
|
||
async def test_semantic_matching():
    """Print how differently worded questions match one cached answer.

    A single question/answer pair is stored under chat id ``test_chat_002``;
    six phrasings (from identical to unrelated) are then looked up and the
    hit/miss outcome plus similarity is printed. Nothing is asserted.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("测试2: 语义匹配验证")
    print(rule)

    redis = await get_redis()
    cache = SemanticCacheService(redis)
    chat_id = "test_chat_002"

    # Reset this chat's cache, then seed it with the reference pair.
    await cache.clear(chat_id=chat_id)
    await cache.store(
        chat_id,
        "公司的退货政策是什么?",
        "我们支持7天无理由退换货,您可以在收到商品后7天内申请退换。",
        redis,
    )

    # (question, label) pairs; the label is printed with each result.
    phrasings = (
        ("公司的退货政策是什么?", "原问题"),
        ("退货政策是怎么样的?", "同义词替换"),
        ("如果我想退货可以吗?", "语义相关"),
        ("你们怎么退货?", "口语化表达"),
        ("产品质量有问题怎么换?", "相关但不同"),
        ("今天天气怎么样?", "完全无关"),
    )

    print("\n语义匹配测试结果:")
    print("-" * 60)
    for question, desc in phrasings:
        hit = await cache.lookup(chat_id, question, redis)
        if hit:
            _answer, similarity = hit
        print(f" [{desc}]")
        print(f" Q: {question[:30]}...")
        print(f" 命中! 相似度={similarity:.2f}" if hit else " 未命中")
|
||
|
||
|
||
async def test_performance():
    """Time cache stores and lookups and print the results.

    Under chat id ``test_chat_003`` this:
      * clears the cache and stores a fixed list of questions (timed);
      * looks up a question that was stored verbatim, ``exact_rounds`` times;
      * looks up a reworded question that was not stored,
        ``semantic_rounds`` times.

    Durations are measured with ``time.perf_counter()``, a monotonic
    high-resolution clock, rather than the wall clock. Counts shown in the
    output are derived from the actual list and loop sizes. Nothing is
    asserted; the numbers are informational.
    """
    print("\n" + "=" * 60)
    print("测试3: 缓存性能验证")
    print("=" * 60)

    redis = await get_redis()
    cache = SemanticCacheService(redis)

    chat_id = "test_chat_003"
    exact_rounds = 100
    semantic_rounds = 10

    await cache.clear(chat_id=chat_id)

    questions = [
        "你们公司提供哪些服务?",
        "如何联系客服?",
        "产品售后政策是什么?",
        "发货需要多长时间?",
        "支持哪些支付方式?",
    ]

    print("\n预热缓存...")
    started = time.perf_counter()
    for index, question in enumerate(questions, start=1):
        await cache.store(chat_id, question, f"这是针对问题的回答 #{index}", redis)
    warm_ms = (time.perf_counter() - started) * 1000
    print(f" 存储{len(questions)}条问答耗时: {warm_ms:.2f}ms")

    print("\n性能测试:")

    # Question stored verbatim above.
    started = time.perf_counter()
    for _ in range(exact_rounds):
        await cache.lookup(chat_id, "你们公司提供哪些服务?", redis)
    exact_ms = (time.perf_counter() - started) * 1000
    print(
        f" {exact_rounds}次精确匹配: {exact_ms:.2f}ms "
        f"(平均{exact_ms / exact_rounds:.2f}ms/次)"
    )

    # Reworded question that was never stored. NOTE(review): the original
    # author's comment says this path scans all cached entries; that cannot
    # be confirmed from this file.
    started = time.perf_counter()
    for _ in range(semantic_rounds):
        await cache.lookup(chat_id, "公司提供什么服务?", redis)
    semantic_ms = (time.perf_counter() - started) * 1000
    print(
        f" {semantic_rounds}次语义匹配: {semantic_ms:.2f}ms "
        f"(平均{semantic_ms / semantic_rounds:.2f}ms/次)"
    )
|
||
|
||
|
||
async def test_cache_statistics():
    """Populate a small cache, query it, and print its statistics.

    Two pairs are stored under chat id ``test_chat_stats`` and three lookups
    are issued before ``cache.get_stats`` is called. The printed fields are
    ``total_entries`` (defaulting to 0 when absent) and three entries of
    ``stats['config']``. Nothing is asserted.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("测试4: 缓存统计功能")
    print(rule)

    redis = await get_redis()
    cache = SemanticCacheService(redis)
    chat_id = "test_chat_stats"

    await cache.clear(chat_id=chat_id)

    # Seed two question/answer pairs.
    for question, answer in (("问题1", "答案1"), ("问题2", "答案2")):
        await cache.store(chat_id, question, answer, redis)

    # Issue a few lookups: the first question twice, the second once.
    for question in ("问题1", "问题1", "问题2"):
        await cache.lookup(chat_id, question, redis)

    stats = await cache.get_stats(chat_id, redis)
    print("\n缓存统计信息:")
    print(f" 总条目数: {stats.get('total_entries', 0)}")
    print(" 配置信息:")
    config = stats["config"]
    print(f" - 最大缓存大小: {config['max_cache_size']}")
    print(f" - 相似度阈值: {config['similarity_threshold']}")
    print(f" - 缓存过期时间: {config['cache_ttl_hours']}小时")
|
||
|
||
|
||
async def test_integration_with_ragflow():
    """Simulate a cache-in-front-of-RAG flow for a short series of questions.

    For each question the cache is consulted first; on a miss a mock RAG
    service (which sleeps 0.1 s) produces the answer, which is then stored
    in the cache. Each step is printed. Uses chat id ``test_integration``.
    Nothing is asserted.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("测试5: 集成测试 - 模拟RAG流程")
    print(rule)

    redis = await get_redis()
    cache = SemanticCacheService(redis)
    chat_id = "test_integration"

    class MockRAGService:
        """Stand-in for the real RAG backend."""

        async def query(self, question):
            # Simulate the latency of a real RAG call.
            await asyncio.sleep(0.1)
            return f"RAG生成的答案:{question}的详细回答..."

    rag_service = MockRAGService()

    await cache.clear(chat_id=chat_id)

    test_questions = [
        "康达机器人有哪些产品?",
        "康达机器人有哪些产品?",  # repeated question: expected to hit the cache
        "你们的机器人产品有哪些?",  # similar question: expected semantic match
        "产品价格是多少?",  # new question: expected miss
    ]

    print("\n模拟用户连续提问:")
    print("-" * 60)

    for turn, question in enumerate(test_questions, start=1):
        print(f"\n用户提问 #{turn}: {question[:30]}...")

        cached = await cache.lookup(chat_id, question, redis)
        if not cached:
            # Miss: fall back to the RAG service and remember its answer.
            print(" ✗ 缓存未命中")
            print(" → 调用RAG服务...")
            answer = await rag_service.query(question)
            print(" → RAG返回答案")
            await cache.store(chat_id, question, answer, redis)
            print(" → 已存入缓存")
        else:
            answer, similarity = cached
            print(f" ✓ 缓存命中! (相似度={similarity:.2f})")
            print(" → 直接返回缓存答案")

        print(f" 答案预览: {answer[:40]}...")
|
||
|
||
|
||
async def main():
    """Run every test scenario in order, then clean up the test data.

    On failure the error message is printed and the exception is re-raised,
    so that the process ends with a non-zero exit status (and the
    interpreter prints the traceback) instead of silently reporting
    success. The cache entries of all test chat ids are cleared in a
    ``finally`` block whether or not the tests passed.

    Raises:
        Exception: whatever a test scenario (or the cleanup) raised,
            including ``AssertionError`` from a failed check.
    """
    print("\n" + "╔" + "═" * 58 + "╗")
    print("║" + " " * 15 + "语义缓存功能测试" + " " * 20 + "║")
    print("╚" + "═" * 58 + "╝")

    # Every chat id written to by the scenarios below.
    test_chat_ids = (
        "test_chat_001",
        "test_chat_002",
        "test_chat_003",
        "test_chat_stats",
        "test_integration",
    )

    try:
        await test_basic_functionality()
        await test_semantic_matching()
        await test_performance()
        await test_cache_statistics()
        await test_integration_with_ragflow()

        print("\n" + "=" * 60)
        print("✓ 所有测试完成!")
        print("=" * 60)
    except Exception as e:
        print(f"\n✗ 测试失败: {e}")
        # Propagate so the script's exit status reflects the failure; the
        # interpreter prints the full traceback for the uncaught exception.
        raise
    finally:
        # Remove everything the scenarios stored, pass or fail.
        redis = await get_redis()
        cache = SemanticCacheService(redis)
        for chat_id in test_chat_ids:
            await cache.clear(chat_id=chat_id)
        print("\n测试数据已清理")


if __name__ == "__main__":
    asyncio.run(main())
|