This commit is contained in:
sladro 2026-02-02 19:16:16 +08:00
parent b685823869
commit 5d04dfb248

View File

@ -1,5 +1,6 @@
import hashlib
import json
import re
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
@ -71,30 +72,18 @@ class KBESService:
chunks, metrics_raw = await cls._es_search(question)
retrieve_ms = (time.perf_counter() - t0) * 1000
match_service = get_match_service()
combined = "\n".join(
[
(c.get("question") or "")
+ "\n"
+ (c.get("answer") or "")
+ "\n"
+ (c.get("title") or "")
+ "\n"
+ (c.get("content") or "")
for c in chunks
]
)
coverage = match_service.calculate_keyword_coverage(question, combined)
chunks, best_cov = cls._rerank_chunks(question, chunks)
top1 = metrics_raw.get("top1_score", 0.0)
gap = metrics_raw.get("score_gap", 0.0)
top1 = float(chunks[0].get("_score") or 0.0) if chunks else 0.0
top2 = float(chunks[1].get("_score") or 0.0) if len(chunks) > 1 else 0.0
gap = top1 - top2
metrics = KBRetrievalMetrics(
retrieve_ms=retrieve_ms,
top1_score=top1,
score_gap=gap,
hit_count=len(chunks),
keyword_coverage=coverage,
keyword_coverage=best_cov,
cache_hit=cache_hit,
)
@ -123,6 +112,63 @@ class KBESService:
return chunks, metrics
@classmethod
def _rerank_chunks(cls, question: str, chunks: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], float]:
"""二次重排轻量、无LLM避免 ES top1 误命中导致直接输出低质答案。
目标尽量保持速度topK 很小只在内存里做简单特征打分
"""
if not chunks:
return chunks, 0.0
match_service = get_match_service()
q = (question or "").strip()
q_norm = q.lower()
# 约束词:如果用户问到公司/年份,命中片段最好也出现对应实体
must_brand = any(x in q_norm for x in ("康达新材", "康达"))
years = set(re.findall(r"20\d{2}", q))
max_es = max([float(c.get("_score") or 0.0) for c in chunks] or [0.0])
if max_es <= 0:
max_es = 1.0
scored: List[Tuple[float, float, float, Dict[str, Any]]] = []
for c in chunks:
text = "\n".join(
[
str(c.get("question") or ""),
str(c.get("answer") or ""),
str(c.get("title") or ""),
str(c.get("content") or ""),
str(c.get("source") or ""),
]
)
text_norm = text.lower()
cov = match_service.calculate_keyword_coverage(q, text)
es = float(c.get("_score") or 0.0)
es_norm = es / max_es
penalty = 0.0
if must_brand and ("康达" not in text_norm):
penalty -= 0.35
if years:
hit_years = any(y in text for y in years)
if not hit_years:
penalty -= 0.25
# 覆盖率更能反映“是否问对文档”ES 分数反映粗排相关性
final = 0.55 * cov + 0.45 * es_norm + penalty
scored.append((final, cov, es, c))
scored.sort(key=lambda x: x[0], reverse=True)
reranked = [c for _, _, _, c in scored]
best_cov = float(scored[0][1]) if scored else 0.0
return reranked, best_cov
@classmethod
def is_confident(cls, metrics: KBRetrievalMetrics) -> bool:
if metrics.hit_count <= 0: