修改相似度阈值到 0.6
This commit is contained in:
parent
31c183c906
commit
acd8af816f
@ -2,8 +2,11 @@ APScheduler==3.11.0
|
||||
asyncpg==0.30.0
|
||||
DateTime==5.5
|
||||
fastapi[all]==0.115.8
|
||||
httpx==0.27.2
|
||||
jieba==0.42.1
|
||||
loguru==0.7.3
|
||||
openpyxl==3.1.5
|
||||
openai==1.37.1
|
||||
pandas==2.2.3
|
||||
passlib[bcrypt]==1.7.4
|
||||
Pillow==11.1.0
|
||||
99
ruoyi-fastapi-backend/test/test_comprehensive_similarity.py
Normal file
99
ruoyi-fastapi-backend/test/test_comprehensive_similarity.py
Normal file
@ -0,0 +1,99 @@
|
||||
"""
|
||||
更全面的语义缓存相似度测试
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from utils.semantic_cache_service import SemanticCacheService
|
||||
|
||||
def test_comprehensive():
|
||||
"""全面的相似度测试"""
|
||||
service = SemanticCacheService()
|
||||
|
||||
# 定义测试用例:[问题1, 问题2, 期望结果(应该相似/应该不相似), 简要说明]
|
||||
test_cases = [
|
||||
# ===== 产品相关 =====
|
||||
("公司有什么产品?", "公司有哪些产品?", True, "产品 - 有什么vs有哪些"),
|
||||
("你们有哪些产品?", "公司有什么产品?", True, "产品 - 你们vs公司"),
|
||||
("介绍一下你们的产品", "产品有哪些?", True, "产品 - 介绍vs哪些"),
|
||||
("我想买你们的东西", "公司有什么产品?", False, "产品 - 语义不同"),
|
||||
|
||||
# ===== 价格相关 =====
|
||||
("这个产品多少钱?", "产品价格是多少?", True, "价格 - 多少钱vs多少"),
|
||||
("费用怎么算?", "需要付多少钱?", True, "价格 - 费用vs钱"),
|
||||
("产品免费吗?", "多少钱?", False, "价格 - 免费vs付费"),
|
||||
|
||||
# ===== 联系方式相关 =====
|
||||
("怎么联系你们?", "电话号码是多少?", True, "联系 - 电话vs联系"),
|
||||
("公司地址在哪?", "怎么去你们公司?", True, "地址 - 位置vs路线"),
|
||||
("联系邮箱", "你们有客服吗?", False, "联系 - 邮箱vs客服"),
|
||||
|
||||
# ===== 服务相关 =====
|
||||
("提供哪些服务?", "你们能做什么?", True, "服务 - 能做什么"),
|
||||
("售后服务怎么样?", "有保修吗?", True, "服务 - 售后vs保修"),
|
||||
("能开发票吗?", "支持报销吗?", True, "服务 - 发票vs报销"),
|
||||
|
||||
# ===== 时间相关 =====
|
||||
("什么时候上班?", "营业时间几点到几点?", True, "时间 - 营业时间"),
|
||||
("几点下班?", "什么时候关门?", True, "时间 - 下班vs关门"),
|
||||
("今天星期几?", "什么时候发货?", False, "时间 - 日期vs发货"),
|
||||
|
||||
# ===== 复杂问法 =====
|
||||
("我想问一下,就是那个关于产品的问题", "公司有什么产品?", True, "复杂 - 口语化表达"),
|
||||
("请问贵公司目前主推的产品型号有哪些?", "有什么产品?", True, "复杂 - 正式表达"),
|
||||
|
||||
# ===== 边界情况 =====
|
||||
("", "公司有什么产品?", False, "空问题"),
|
||||
("公司", "公司有什么产品?", False, "问题太短"),
|
||||
("公司有什么产品?公司有什么产品?", "公司有什么产品?", True, "重复问题"),
|
||||
]
|
||||
|
||||
print("🔍 语义相似度全面测试")
|
||||
print("=" * 80)
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for q1, q2, should_match, description in test_cases:
|
||||
similarity = service._calculate_text_similarity(q1, q2)
|
||||
keywords1 = service._extract_keywords(q1)
|
||||
keywords2 = service._extract_keywords(q2)
|
||||
|
||||
threshold = service.SIMILARITY_THRESHOLD
|
||||
actual_match = similarity >= threshold
|
||||
|
||||
# 判断是否正确
|
||||
if actual_match == should_match:
|
||||
status = "✅ 通过"
|
||||
passed += 1
|
||||
else:
|
||||
status = "❌ 失败"
|
||||
failed += 1
|
||||
|
||||
match_type = "命中" if actual_match else "未命中"
|
||||
expected = "应该命中" if should_match else "应该不命中"
|
||||
|
||||
print(f"\n{description}")
|
||||
print(f" Q1: {q1 or '(空)'}")
|
||||
print(f" Q2: {q2 or '(空)'}")
|
||||
print(f" 关键词1: {keywords1}")
|
||||
print(f" 关键词2: {keywords2}")
|
||||
print(f" 相似度: {similarity:.3f} (阈值: {threshold})")
|
||||
print(f" {status} | {match_type} | {expected}")
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print(f"📊 测试结果: {passed} 通过, {failed} 失败, 总计 {len(test_cases)}")
|
||||
print(f" 通过率: {passed/len(test_cases)*100:.1f}%")
|
||||
|
||||
# 统计分析
|
||||
print("\n📈 相似度分布:")
|
||||
similarities = [service._calculate_text_similarity(q1, q2) for q1, q2, _, _ in test_cases]
|
||||
print(f" 最高: {max(similarities):.3f}")
|
||||
print(f" 最低: {min(similarities):.3f}")
|
||||
print(f" 平均: {sum(similarities)/len(similarities):.3f}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_comprehensive()
|
||||
@ -67,7 +67,7 @@ class SemanticCacheService:
|
||||
# 缓存配置
|
||||
CACHE_PREFIX = "rag:semantic:cache"
|
||||
MAX_CACHE_SIZE = 1000 # 每个chat_id最大缓存条数
|
||||
SIMILARITY_THRESHOLD = 0.75 # 相似度阈值
|
||||
SIMILARITY_THRESHOLD = 0.6 # 相似度阈值
|
||||
CACHE_TTL_HOURS = 24 # 缓存过期时间(小时)
|
||||
|
||||
def __init__(self, redis_client=None):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user