改成异步生成器

This commit is contained in:
Tian jianyong 2025-12-24 20:27:11 +08:00
parent ac756fb4ee
commit 57406f922e

View File

@ -6,9 +6,9 @@ import asyncio
import hashlib
import json
import time
from typing import Any, Dict, Generator, Optional, Tuple
from typing import Any, AsyncGenerator, Dict, Generator, Optional, Tuple
from fastapi import APIRouter, BackgroundTasks, Depends, Request
from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse
from module_admin.service.login_service import LoginService
@ -96,7 +96,6 @@ def remove_style_hint(question: str) -> Tuple[str, bool]:
@ragflowController.post('/converse_with_chat_assistant')
async def converse_with_chat_assistant(
request: Request,
background_tasks: BackgroundTasks,
converse_params: ConverseWithChatAssistantModel,
):
"""
@ -203,21 +202,16 @@ async def converse_with_chat_assistant(
store_hash = get_question_hash(cache_question)
logger.info(f'[RAG_CACHE] 准备存储 | chat_id={converse_params.chat_id} | question={cache_question} | hash={store_hash}')
answer_queue = asyncio.Queue()
async def cache_store_background():
try:
answer = await answer_queue.get()
if redis and answer and len(answer.strip()) >= 10:
async def cache_store_func(answer: str):
if redis and answer and len(answer.strip()) >= 10:
try:
await _async_store_qa(converse_params.chat_id, cache_question, answer, redis)
logger.info(f'[RAG_CACHE] 语义缓存存储成功 | chat_id={converse_params.chat_id} | answer_length={len(answer)}')
except Exception as e:
logger.warning(f'[RAG_CACHE] 语义缓存存储失败: {e}')
background_tasks.add_task(cache_store_background)
except Exception as e:
logger.warning(f'[RAG_CACHE] 语义缓存存储失败: {e}')
return StreamingResponse(
stream_ragflow_response(result, converse_params.chat_id, start_time, answer_queue=answer_queue),
stream_ragflow_response(result, converse_params.chat_id, start_time, cache_store_func=cache_store_func),
media_type='text/event-stream',
headers={
'Cache-Control': 'no-cache',
@ -232,9 +226,9 @@ async def converse_with_chat_assistant(
return ResponseUtil.error(msg=f'对话服务异常: {str(e)}')
def stream_ragflow_response(result: Generator, chat_id: str, start_time: float, answer_queue=None) -> Generator[str, None, None]:
async def stream_ragflow_response(result: Generator, chat_id: str, start_time: float, cache_store_func=None) -> AsyncGenerator[str, None]:
"""
流式处理RAGFlow响应 - 同步版本修复首token延迟
流式处理RAGFlow响应 - 异步版本修复首token延迟
"""
import time
@ -325,17 +319,12 @@ def stream_ragflow_response(result: Generator, chat_id: str, start_time: float,
total_time = time.perf_counter() - start_time
logger.info(f'[RAG_SERVER {time.time():.3f}] ⏱️ Total Stream Duration ({chat_id}): {total_time:.3f}s')
if answer_queue and last_answer and len(last_answer.strip()) >= 10:
if cache_store_func and last_answer and len(last_answer.strip()) >= 10:
try:
import asyncio
loop = asyncio.get_event_loop()
if loop.is_running():
asyncio.run_coroutine_threadsafe(answer_queue.put(last_answer), loop)
else:
loop.run_until_complete(answer_queue.put(last_answer))
logger.info(f'[RAG_CACHE] 已将答案放入队列 | chat_id={chat_id} | answer_length={len(last_answer)}')
await cache_store_func(last_answer)
logger.info(f'[RAG_CACHE] 缓存存储完成 | chat_id={chat_id} | answer_length={len(last_answer)}')
except Exception as cache_err:
logger.warning(f'[RAG_CACHE] 队列写入失败: {cache_err}')
logger.warning(f'[RAG_CACHE] 缓存存储失败: {cache_err}')
# 聊天助手列表