kangda-robot-backend/ruoyi-fastapi-backend/module_admin/controller/ragflow_controller_backup.py
2025-12-17 14:58:09 +08:00

376 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# from datetime import datetime
import hashlib
import json
import time
from typing import List
from fastapi import APIRouter, Depends, Request, UploadFile, File
from fastapi.responses import StreamingResponse
# from pydantic_validation_decorator import ValidateFields
# from sqlalchemy.ext.asyncio import AsyncSession
# from config.enums import BusinessType
# from config.get_db import get_db
from module_admin.aspect.interface_auth import CheckRoleInterfaceAuth
# from module_admin.entity.vo.notice_vo import DeleteNoticeModel, NoticeModel, NoticePageQueryModel
# from module_admin.entity.vo.user_vo import CurrentUserModel
from module_admin.service.login_service import LoginService
from module_admin.service.ragflow_service import RAGFlowService
from module_admin.service.search_service import SearchService, SearchServiceError
from utils.log_util import logger
# from utils.page_util import PageResponseModel
from utils.response_util import ResponseUtil
from module_admin.entity.vo.ragflow_vo import RagflowListQueryModel, ListDocumentsQueryModel, UpdateFileModel, DeleteFileModel, CreateDatasetModel, DocumentIdsModel, UpdateChatAssistantModel,\
CreateSessionWithChatModel, ConverseWithChatAssistantModel
# from config.env import RAGFlowConfig
import asyncio
# Router for every RAGFlow endpoint in this module; each route depends on
# LoginService.get_current_user.
ragflowController = APIRouter(
    dependencies=[Depends(LoginService.get_current_user)],
    prefix="/system/ragflow",
)
# List datasets.
# NOTE(review): a per-route permission check
# (CheckUserInterfaceAuth("system:ragflow:list")) and a PageResponseModel
# response_model were sketched for this route but are disabled; only the
# router-level dependency applies.
@ragflowController.post("/dataset_list")
async def get_system_ragflow_list(
    request: Request,
    rage_flow_dastset_query: RagflowListQueryModel,
):
    """Return the RAGFlow datasets matching the posted query.

    The service's first argument (presumably an unused DB session) is None.
    """
    dataset_list = await RAGFlowService.get_ragflow_dataset_list_services(None, rage_flow_dastset_query)
    return parse_result(dataset_list)
# Create a dataset.
@ragflowController.post("/create_dataset")
async def create_dataset(
    request: Request,
    create_dataset_params: CreateDatasetModel,
):
    """Create a RAGFlow dataset from the request body and wrap the service result."""
    return parse_result(await RAGFlowService.create_dataset_services(create_dataset_params))
# Update a dataset.
@ragflowController.post("/update_dataset/{dataset_id}")
async def update_dataset(
    request: Request,
    dataset_id: str,
    update_dataset_params: CreateDatasetModel,
):
    """Update the dataset identified by ``dataset_id``.

    The request body reuses ``CreateDatasetModel``.
    """
    service_result = await RAGFlowService.update_dataset_services(dataset_id, update_dataset_params)
    return parse_result(service_result)
# List the documents of a dataset.
@ragflowController.get("/list_documents/{dataset_id}")
async def list_documents_by_dataset_id(
    request: Request,
    dataset_id: str,
    list_documents_query: ListDocumentsQueryModel = Depends(ListDocumentsQueryModel.as_query),
):
    """List the documents stored in RAGFlow dataset ``dataset_id``.

    Query-string parameters are bound through ``ListDocumentsQueryModel.as_query``.
    The service's first argument (presumably an unused DB session) is None.
    """
    # Debug trace goes through the application logger rather than stdout.
    logger.debug(f'list_documents dataset_id={dataset_id} query={list_documents_query}')
    result = await RAGFlowService.list_documents_services(None, dataset_id, list_documents_query)
    return parse_result(result)
# Upload files to a dataset.
@ragflowController.post("/upload_file/{dataset_id}")
async def upload_file_dataset(
    dataset_id: str,
    files: List[UploadFile] = File(...),
):
    """Upload one or more multipart files into dataset ``dataset_id``."""
    upload_result = await RAGFlowService.upload_file_dataset_services(None, dataset_id, files)
    return parse_result(upload_result)
# Update a document.
@ragflowController.post("/update_file/{dataset_id}/{document_id}")
async def update_file_dataset(
    dataset_id: str,
    document_id: str,
    update_params: UpdateFileModel,
):
    """Update document ``document_id`` inside dataset ``dataset_id``."""
    return parse_result(
        await RAGFlowService.update_file_dataset_services(dataset_id, document_id, update_params)
    )
# Start parsing documents.
@ragflowController.post("/parse_documents/{dataset_id}")
async def parse_documents(
    dataset_id: str,
    parse_params: DocumentIdsModel,
):
    """Ask RAGFlow to start parsing the listed documents of ``dataset_id``."""
    parse_outcome = await RAGFlowService.parse_documents_services(dataset_id, parse_params)
    return parse_result(parse_outcome)
# Stop parsing documents.
@ragflowController.post("/stop_parse_documents/{dataset_id}")
async def stop_parse_documents(
    dataset_id: str,
    parse_params: DocumentIdsModel,
):
    """Ask RAGFlow to stop parsing the listed documents of ``dataset_id``."""
    return parse_result(
        await RAGFlowService.stop_parse_documents_services(dataset_id, parse_params)
    )
# Delete documents.
@ragflowController.post("/delete_file/{dataset_id}")
async def delete_file(
    dataset_id: str,
    delete_params: DeleteFileModel,
):
    """Delete the documents described by ``delete_params`` from ``dataset_id``."""
    deletion = await RAGFlowService.delete_file_services(dataset_id, delete_params)
    return parse_result(deletion)
# Delete datasets.
@ragflowController.post("/delete_datasets")
async def delete_datasets(
    delete_params: DeleteFileModel,
):
    """Delete the datasets described by ``delete_params`` (reuses ``DeleteFileModel``)."""
    return parse_result(await RAGFlowService.delete_datasets_services(delete_params))
# List chat assistants.
@ragflowController.post("/get_chat_assistant_list")
async def get_chat_assistant_list(
    query_params: RagflowListQueryModel,
):
    """Return the RAGFlow chat assistants matching the posted query."""
    assistants = await RAGFlowService.get_chat_assistant_list_services(query_params)
    return parse_result(assistants)
# Update a chat assistant.
@ragflowController.post("/update_chat_assistant")
async def update_chat_assistant(
    update_params: UpdateChatAssistantModel,
):
    """Update a RAGFlow chat assistant with the posted settings."""
    return parse_result(await RAGFlowService.update_chat_assistant_services(update_params))
# Create a session belonging to a chat assistant.
@ragflowController.post("/create_session_with_chat")
async def create_session_with_chat(
    create_params: CreateSessionWithChatModel,
):
    """Open a new conversation session under a RAGFlow chat assistant."""
    new_session = await RAGFlowService.create_session_with_chat_services(create_params)
    return parse_result(new_session)
# Converse with a chat assistant.
@ragflowController.post('/converse_with_chat_assistant')
async def converse_with_chat_assistant(
    request: Request,
    converse_params: ConverseWithChatAssistantModel,
):
    """Answer a chat question via web search or a RAGFlow chat assistant.

    Flow:
    1. ``SearchService.classify_intent`` routes the question; a ``'SEARCH'``
       intent is answered by ``SearchService.handle_search_chat``.
    2. Non-streaming requests first try Redis (``app.state.redis``, if set)
       under ``build_chat_cache_key(chat_id, question)``.
    3. Otherwise ``RAGFlowService.converse_with_chat_assistant_services`` is
       called. A streaming result is relayed as Server-Sent Events; a
       non-streaming result goes through ``parse_result`` and, when its
       ``code`` is 0, its ``data`` is cached for 60 seconds.

    Streaming protocol: one ``data:`` event per chunk (RAGFlow's cumulative
    ``answer`` text is converted to deltas), then ``event: end`` with
    ``{'status': 'completed'}``. On failure a single ``event: error`` carrying
    a ``message`` is sent instead of ``end``.
    """
    start_time = time.perf_counter()
    redis = getattr(request.app.state, 'redis', None)
    cache_key = None

    # Router strategy: classify the intent (SEARCH vs RAG) before anything else.
    t0 = time.perf_counter()
    intent = await SearchService.classify_intent(converse_params.question)
    logger.info(f"[RAG_PROFILE] Intent Router ({converse_params.chat_id}): {intent} | Time: {time.perf_counter() - t0:.3f}s | Query: {converse_params.question}")
    if intent == 'SEARCH':
        return await SearchService.handle_search_chat(converse_params, redis)

    # Only non-streaming answers are cached.
    if not converse_params.stream and redis:
        cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
        cached = await redis.get(cache_key)
        if cached:
            # f-string like the other log calls in this file: %-style args are
            # not interpolated by every logger backend (e.g. loguru).
            logger.info(f'ragflow对话命中缓存: chat={converse_params.chat_id}')
            return ResponseUtil.success(json.loads(cached))

    # Native RAGFlow service (the OpenAI-compatible endpoint failed with an auth error).
    # NOTE(review): this call is not awaited; if it performs blocking network I/O
    # it runs on the event-loop thread — consider asyncio.to_thread.
    t1 = time.perf_counter()
    result = RAGFlowService.converse_with_chat_assistant_services(converse_params)
    logger.info(f"[RAG_PROFILE] RAG Init/Connect ({converse_params.chat_id}): {time.perf_counter() - t1:.3f}s")

    if converse_params.stream:
        async def stream_response():
            last_answer = ""
            first_token_received = False
            start_stream_time = time.perf_counter()
            try:
                # The service hands back a synchronous iterator. Advance it in a
                # worker thread so a blocking read does not stall the event loop.
                # A private sentinel marks exhaustion, so falsy chunks still reach
                # the checks below.
                chunks = iter(result)
                exhausted = object()
                while True:
                    chunk = await asyncio.to_thread(next, chunks, exhausted)
                    if chunk is exhausted:
                        break
                    if not first_token_received:
                        first_token_received = True
                        latency = time.perf_counter() - start_stream_time
                        logger.info(f"[RAG_PROFILE] Time to First Token ({converse_params.chat_id}): {latency:.3f}s")

                    # Log any non-zero RAGFlow code, whether or not the chunk has an answer.
                    error_code = chunk.get('code') if isinstance(chunk, dict) else None
                    if error_code:
                        logger.error(f"RAGFlow Stream Error: {chunk}")

                    # Expected chunk structure: {'data': {'answer': 'cumulative text...'}}
                    payload = chunk.get('data') if isinstance(chunk, dict) else chunk
                    if not payload:
                        if error_code:
                            # An error chunk without data would otherwise be skipped
                            # silently; report it and stop, mirroring the except path.
                            message = chunk.get('message') or chunk.get('msg') or '接口异常'
                            yield format_sse({'message': message}, event='error')
                            return
                        continue

                    body = payload if isinstance(payload, dict) else {'data': payload}
                    if 'answer' not in body:
                        yield format_sse(body)
                        continue

                    # RAGFlow sends the cumulative answer; forward only the new suffix.
                    current_answer = body['answer']
                    if current_answer.startswith(last_answer):
                        delta = current_answer[len(last_answer):]
                        # NOTE(review): a chunk whose answer did not grow is dropped
                        # entirely, including any other fields it carries.
                        if delta:
                            body['answer'] = delta
                            last_answer = current_answer
                            yield format_sse(body)
                    else:
                        # Answer was reset: forward it whole and restart delta tracking.
                        last_answer = current_answer
                        yield format_sse(body)
                yield format_sse({'status': 'completed'}, event='end')
            except Exception as exc:
                logger.exception(f'[RAG_PROFILE] ragflow流式对话异常: {exc}')
                yield format_sse({'message': str(exc)}, event='error')
            finally:
                total_time = time.perf_counter() - start_time
                logger.info(f'[RAG_PROFILE] Total Stream Duration ({converse_params.chat_id}): {total_time:.3f}s')

        return StreamingResponse(stream_response(), media_type='text/event-stream')

    response = parse_result(result)
    if redis and cache_key and isinstance(result, dict) and result.get('code') == 0:
        await redis.set(cache_key, json.dumps(result.get('data'), ensure_ascii=False), ex=60)
    logger.info(f'[RAG_PROFILE] Total Sync Duration ({converse_params.chat_id}): {time.perf_counter() - start_time:.3f}s')
    return response
# Get the current user's permissions.
@ragflowController.get('/get_user_permission', dependencies=[Depends(CheckRoleInterfaceAuth('pad'))])
async def get_user_permission(current_user = Depends(LoginService.get_current_user)):
    """Return the permission list of the current user.

    The route is additionally guarded by ``CheckRoleInterfaceAuth('pad')``.
    """
    user_auth_list = current_user.permissions
    # Debug trace via the application logger instead of a print() to stdout.
    logger.debug(f'get_user_permission: {user_auth_list}')
    return ResponseUtil.success(data=user_auth_list)
def parse_result(result):
    """Wrap a RAGFlow-style ``{'code', 'message'/'msg', 'data'}`` dict in a ResponseUtil response.

    A missing ``code`` counts as success (0). On failure the message is taken
    from ``message``, then ``msg``, falling back to a generic text. ``data`` is
    passed through in both cases.
    """
    data = result.get('data', None)
    if result.get('code', 0) == 0:
        return ResponseUtil.success(data=data)
    error_msg = result.get('message') or result.get('msg') or '接口异常'
    return ResponseUtil.error(msg=error_msg, data=data)
def build_chat_cache_key(chat_id: str, question: str) -> str:
    """Return the Redis key for a cached non-streamed answer.

    The question is reduced to the SHA-256 hex digest of its UTF-8 bytes, so
    arbitrary user text becomes a fixed-length component free of ``:``.
    """
    question_hash = hashlib.sha256(question.encode('utf-8'))
    return f'ragflow:chat:{chat_id}:{question_hash.hexdigest()}'
def format_sse(data: dict, event: str | None = None) -> str:
payload = json.dumps(data, ensure_ascii=False)
prefix = f'event: {event}\n' if event else ''
return f'{prefix}data: {payload}\n\n'
async def aiter_sync_generator(sync_gen):
    """Asynchronously iterate a blocking (synchronous) iterable.

    Each ``next()`` call runs in the event loop's default executor, so slow
    producers do not block the loop. A private sentinel marks exhaustion, so
    items that are ``None`` are yielded rather than ending the iteration.
    Any plain iterable is accepted, not only generators.

    Exceptions raised by the underlying iterator propagate to the consumer.
    """
    iterator = iter(sync_gen)
    loop = asyncio.get_running_loop()
    exhausted = object()
    while True:
        value = await loop.run_in_executor(None, next, iterator, exhausted)
        if value is exhausted:
            break
        yield value
import asyncio
import hashlib
import json
import time
from typing import AsyncIterator, Union
import aiofiles
import redis
from fastapi import Depends, Request, Response
from fastapi.responses import StreamingResponse
from fastapi_controller.decorators import Controller
from fastapi_controller.dependencies import LoginService
from fastapi_controller.response import ResponseUtil
from sqlalchemy.ext.asyncio import AsyncSession
from module_admin.service.ragflow_service import RAGFlowService
from module_admin.service.search_service import SearchService
from module_admin.service.login_service import CheckRoleInterfaceAuth
from module_admin.model.ragflow_models import (
ConverseWithChatAssistantModel,
CreateSessionWithChatModel,
DeleteFileModel,
DocumentIdsModel,
RagflowListQueryModel,
UpdateChatAssistantModel,
)
import logger as logger_module