update
This commit is contained in:
parent
b685823869
commit
5d04dfb248
@ -1,5 +1,6 @@
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple
|
||||
@ -71,30 +72,18 @@ class KBESService:
|
||||
chunks, metrics_raw = await cls._es_search(question)
|
||||
retrieve_ms = (time.perf_counter() - t0) * 1000
|
||||
|
||||
match_service = get_match_service()
|
||||
combined = "\n".join(
|
||||
[
|
||||
(c.get("question") or "")
|
||||
+ "\n"
|
||||
+ (c.get("answer") or "")
|
||||
+ "\n"
|
||||
+ (c.get("title") or "")
|
||||
+ "\n"
|
||||
+ (c.get("content") or "")
|
||||
for c in chunks
|
||||
]
|
||||
)
|
||||
coverage = match_service.calculate_keyword_coverage(question, combined)
|
||||
chunks, best_cov = cls._rerank_chunks(question, chunks)
|
||||
|
||||
top1 = metrics_raw.get("top1_score", 0.0)
|
||||
gap = metrics_raw.get("score_gap", 0.0)
|
||||
top1 = float(chunks[0].get("_score") or 0.0) if chunks else 0.0
|
||||
top2 = float(chunks[1].get("_score") or 0.0) if len(chunks) > 1 else 0.0
|
||||
gap = top1 - top2
|
||||
|
||||
metrics = KBRetrievalMetrics(
|
||||
retrieve_ms=retrieve_ms,
|
||||
top1_score=top1,
|
||||
score_gap=gap,
|
||||
hit_count=len(chunks),
|
||||
keyword_coverage=coverage,
|
||||
keyword_coverage=best_cov,
|
||||
cache_hit=cache_hit,
|
||||
)
|
||||
|
||||
@ -123,6 +112,63 @@ class KBESService:
|
||||
|
||||
return chunks, metrics
|
||||
|
||||
@classmethod
def _rerank_chunks(cls, question: str, chunks: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], float]:
    """Re-rank retrieved chunks in memory with a lightweight, LLM-free score.

    Guards against a wrong ES top-1 hit being returned directly as a
    low-quality answer. The candidate list is expected to be a small top-K,
    so only cheap per-chunk features are computed.

    Each chunk's final score is
    ``0.55 * keyword_coverage + 0.45 * (es_score / max_es_score) + penalty``
    where ``penalty`` is lowered by 0.35 when the question names the brand
    but the chunk text does not, and by 0.25 when the question mentions
    ``20xx`` years and none of them occur in the chunk text.

    Args:
        question: The user question; ``None``/empty is treated as ``""``.
        chunks: Hit dicts. The keys read are ``question``, ``answer``,
            ``title``, ``content``, ``source`` and ``_score``.

    Returns:
        A tuple ``(reranked, best_cov)``: the chunks ordered by descending
        final score (ties keep their incoming order) and the keyword
        coverage of the new first chunk. An empty input is returned as-is
        together with ``0.0``.
    """
    if not chunks:
        return chunks, 0.0

    matcher = get_match_service()
    query = (question or "").strip()
    query_lower = query.lower()

    # Constraint terms: when the user asks about the company or a year,
    # a good hit should mention the same entity.
    brand_required = "康达新材" in query_lower or "康达" in query_lower
    wanted_years = set(re.findall(r"20\d{2}", query))

    # Raw ES scores, collected once; used for normalisation and scoring.
    raw_scores = [float(chunk.get("_score") or 0.0) for chunk in chunks]
    top_raw = max(raw_scores, default=0.0)
    if top_raw <= 0:
        # Avoid dividing by zero (or a negative) when no hit has a usable score.
        top_raw = 1.0

    text_fields = ("question", "answer", "title", "content", "source")
    rows: List[Tuple[float, float, float, Dict[str, Any]]] = []
    for chunk, raw in zip(chunks, raw_scores):
        haystack = "\n".join(str(chunk.get(field) or "") for field in text_fields)
        haystack_lower = haystack.lower()

        coverage = matcher.calculate_keyword_coverage(query, haystack)
        relative = raw / top_raw

        penalty = 0.0
        if brand_required and "康达" not in haystack_lower:
            penalty -= 0.35
        if wanted_years and not any(year in haystack for year in wanted_years):
            penalty -= 0.25

        # Coverage better reflects "did we hit the right document";
        # the ES score reflects coarse-ranking relevance.
        total = 0.55 * coverage + 0.45 * relative + penalty
        rows.append((total, coverage, raw, chunk))

    # sorted() is stable, so equal totals keep their incoming order.
    ordered = sorted(rows, key=lambda row: row[0], reverse=True)
    top_coverage = float(ordered[0][1]) if ordered else 0.0
    return [row[3] for row in ordered], top_coverage
|
||||
|
||||
@classmethod
|
||||
def is_confident(cls, metrics: KBRetrievalMetrics) -> bool:
|
||||
if metrics.hit_count <= 0:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user