diff --git a/ruoyi-fastapi-backend/module_admin/service/kb_es_service.py b/ruoyi-fastapi-backend/module_admin/service/kb_es_service.py
index eb9826b..b78ad6a 100644
--- a/ruoyi-fastapi-backend/module_admin/service/kb_es_service.py
+++ b/ruoyi-fastapi-backend/module_admin/service/kb_es_service.py
@@ -1,5 +1,6 @@
 import hashlib
 import json
+import re
 import time
 from dataclasses import dataclass
 from typing import Any, Dict, List, Tuple
@@ -71,30 +72,18 @@ class KBESService:
         chunks, metrics_raw = await cls._es_search(question)
         retrieve_ms = (time.perf_counter() - t0) * 1000
 
-        match_service = get_match_service()
-        combined = "\n".join(
-            [
-                (c.get("question") or "")
-                + "\n"
-                + (c.get("answer") or "")
-                + "\n"
-                + (c.get("title") or "")
-                + "\n"
-                + (c.get("content") or "")
-                for c in chunks
-            ]
-        )
-        coverage = match_service.calculate_keyword_coverage(question, combined)
+        chunks, best_cov = cls._rerank_chunks(question, chunks)
 
-        top1 = metrics_raw.get("top1_score", 0.0)
-        gap = metrics_raw.get("score_gap", 0.0)
+        top1 = float(chunks[0].get("_score") or 0.0) if chunks else 0.0
+        top2 = float(chunks[1].get("_score") or 0.0) if len(chunks) > 1 else 0.0
+        gap = top1 - top2
 
         metrics = KBRetrievalMetrics(
             retrieve_ms=retrieve_ms,
             top1_score=top1,
             score_gap=gap,
             hit_count=len(chunks),
-            keyword_coverage=coverage,
+            keyword_coverage=best_cov,
             cache_hit=cache_hit,
         )
 
@@ -123,6 +112,63 @@ class KBESService:
 
         return chunks, metrics
 
+    @classmethod
+    def _rerank_chunks(cls, question: str, chunks: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], float]:
+        """二次重排(轻量、无LLM),避免 ES top1 误命中导致直接输出低质答案。
+
+        目标:尽量保持速度(topK 很小),只在内存里做简单特征打分。
+        """
+
+        if not chunks:
+            return chunks, 0.0
+
+        match_service = get_match_service()
+        q = (question or "").strip()
+        q_norm = q.lower()
+
+        # 约束词:如果用户问到公司/年份,命中片段最好也出现对应实体
+        must_brand = any(x in q_norm for x in ("康达新材", "康达"))
+        years = set(re.findall(r"20\d{2}", q))
+
+        max_es = max([float(c.get("_score") or 0.0) for c in chunks] or [0.0])
+        if max_es <= 0:
+            max_es = 1.0
+
+        scored: List[Tuple[float, float, float, Dict[str, Any]]] = []
+        for c in chunks:
+            text = "\n".join(
+                [
+                    str(c.get("question") or ""),
+                    str(c.get("answer") or ""),
+                    str(c.get("title") or ""),
+                    str(c.get("content") or ""),
+                    str(c.get("source") or ""),
+                ]
+            )
+            text_norm = text.lower()
+
+            cov = match_service.calculate_keyword_coverage(q, text)
+            es = float(c.get("_score") or 0.0)
+            es_norm = es / max_es
+
+            penalty = 0.0
+            if must_brand and ("康达" not in text_norm):
+                penalty -= 0.35
+            if years:
+                hit_years = any(y in text for y in years)
+                if not hit_years:
+                    penalty -= 0.25
+
+            # 覆盖率更能反映“是否问对文档”,ES 分数反映粗排相关性
+            final = 0.55 * cov + 0.45 * es_norm + penalty
+
+            scored.append((final, cov, es, c))
+
+        scored.sort(key=lambda x: x[0], reverse=True)
+        reranked = [c for _, _, _, c in scored]
+        best_cov = float(scored[0][1]) if scored else 0.0
+        return reranked, best_cov
+
     @classmethod
     def is_confident(cls, metrics: KBRetrievalMetrics) -> bool:
         if metrics.hit_count <= 0: