kangda-robot-backend/ruoyi-fastapi-backend/module_admin/controller/ragflow_controller.py

327 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# from datetime import datetime
import hashlib
import json
import time
from typing import List
from fastapi import APIRouter, Depends, Request, UploadFile, File
from fastapi.responses import StreamingResponse
# from pydantic_validation_decorator import ValidateFields
# from sqlalchemy.ext.asyncio import AsyncSession
# from config.enums import BusinessType
# from config.get_db import get_db
from module_admin.aspect.interface_auth import CheckRoleInterfaceAuth
# from module_admin.entity.vo.notice_vo import DeleteNoticeModel, NoticeModel, NoticePageQueryModel
# from module_admin.entity.vo.user_vo import CurrentUserModel
from module_admin.service.login_service import LoginService
from module_admin.service.ragflow_service import RAGFlowService
from module_admin.service.search_service import SearchService, SearchServiceError
from utils.log_util import logger
# from utils.page_util import PageResponseModel
from utils.response_util import ResponseUtil
from module_admin.entity.vo.ragflow_vo import RagflowListQueryModel, ListDocumentsQueryModel, UpdateFileModel, DeleteFileModel, CreateDatasetModel, DocumentIdsModel, UpdateChatAssistantModel,\
CreateSessionWithChatModel, ConverseWithChatAssistantModel
# from config.env import RAGFlowConfig
# Router for the RAGFlow endpoints, mounted under the /system/ragflow prefix.
# Every route registered on it runs the LoginService.get_current_user dependency.
ragflowController = APIRouter(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)])
# 查看数据集列表
@ragflowController.post("/dataset_list")
async def get_system_ragflow_list(
    request: Request,
    rage_flow_dastset_query: RagflowListQueryModel,
):
    """Return the RAGFlow dataset list matching the posted query.

    :param request: Incoming request; not read by this handler.
    :param rage_flow_dastset_query: Request body forwarded to the service unchanged.
    :return: The service result wrapped by ``parse_result``.
    """
    # The service's first positional argument is passed as None; it presumably
    # stood for a DB session that this endpoint no longer injects -- TODO confirm.
    # NOTE: the response_model and per-interface permission dependency that were
    # commented out on this route remain disabled.
    return parse_result(
        await RAGFlowService.get_ragflow_dataset_list_services(None, rage_flow_dastset_query)
    )
# 创建数据集
@ragflowController.post('/create_dataset')
async def create_dataset(
    request: Request,
    create_dataset_params: CreateDatasetModel,
):
    """Create a dataset from the posted parameters.

    :param request: Incoming request; not read by this handler.
    :param create_dataset_params: Request body forwarded to the service unchanged.
    :return: The service result wrapped by ``parse_result``.
    """
    service_response = await RAGFlowService.create_dataset_services(create_dataset_params)
    return parse_result(service_response)
# 更新数据集
@ragflowController.post('/update_dataset/{dataset_id}')
async def update_dataset(
    request: Request,
    dataset_id: str,
    update_dataset_params: CreateDatasetModel,
):
    """Update the dataset identified by the path parameter.

    :param request: Incoming request; not read by this handler.
    :param dataset_id: Dataset id taken from the URL path.
    :param update_dataset_params: Request body; validated with the same model
        that dataset creation uses.
    :return: The service result wrapped by ``parse_result``.
    """
    service_response = await RAGFlowService.update_dataset_services(
        dataset_id,
        update_dataset_params,
    )
    return parse_result(service_response)
# 列出数据集中文档列表
@ragflowController.get("/list_documents/{dataset_id}")
async def list_documents_by_dataset_id(
    request: Request,
    dataset_id: str,
    list_documents_query: ListDocumentsQueryModel = Depends(ListDocumentsQueryModel.as_query),
):
    """List the documents of one dataset.

    :param request: Incoming request; not read by this handler.
    :param dataset_id: Dataset id taken from the URL path.
    :param list_documents_query: Query model resolved through the
        ``ListDocumentsQueryModel.as_query`` dependency.
    :return: The service result wrapped by ``parse_result``.
    """
    # Route the diagnostic through the module logger instead of a bare print()
    # so it is levelled and goes to the configured log sinks.
    logger.debug(f'list_documents query ({dataset_id}): {list_documents_query}')
    # The service's first positional argument is passed as None; it presumably
    # stood for a DB session that this endpoint no longer injects -- TODO confirm.
    result = await RAGFlowService.list_documents_services(None, dataset_id, list_documents_query)
    return parse_result(result)
# 上传文件到数据集
@ragflowController.post("/upload_file/{dataset_id}")
async def upload_file_dataset(
    dataset_id: str,
    files: List[UploadFile] = File(...),
):
    """Upload one or more files into a dataset.

    :param dataset_id: Dataset id taken from the URL path.
    :param files: Required multipart file uploads.
    :return: The service result wrapped by ``parse_result``.
    """
    # The service's first positional argument is passed as None; it presumably
    # stood for a DB session that this endpoint no longer injects -- TODO confirm.
    service_response = await RAGFlowService.upload_file_dataset_services(
        None,
        dataset_id,
        files,
    )
    return parse_result(service_response)
# 更新文档
@ragflowController.post("/update_file/{dataset_id}/{document_id}")
async def update_file_dataset(
    dataset_id: str,
    document_id: str,
    update_params: UpdateFileModel,
):
    """Update a document that belongs to a dataset.

    :param dataset_id: Dataset id taken from the URL path.
    :param document_id: Document id taken from the URL path.
    :param update_params: Request body forwarded to the service unchanged.
    :return: The service result wrapped by ``parse_result``.
    """
    service_response = await RAGFlowService.update_file_dataset_services(
        dataset_id,
        document_id,
        update_params,
    )
    return parse_result(service_response)
# 开始解析文档
@ragflowController.post('/parse_documents/{dataset_id}')
async def parse_documents(
    dataset_id: str,
    parse_params: DocumentIdsModel,
):
    """Start parsing the given documents of a dataset.

    :param dataset_id: Dataset id taken from the URL path.
    :param parse_params: Request body forwarded to the service unchanged.
    :return: The service result wrapped by ``parse_result``.
    """
    service_response = await RAGFlowService.parse_documents_services(dataset_id, parse_params)
    return parse_result(service_response)
# 停止解析文档
@ragflowController.post('/stop_parse_documents/{dataset_id}')
async def stop_parse_documents(
    dataset_id: str,
    parse_params: DocumentIdsModel,
):
    """Stop parsing the given documents of a dataset.

    :param dataset_id: Dataset id taken from the URL path.
    :param parse_params: Request body forwarded to the service unchanged.
    :return: The service result wrapped by ``parse_result``.
    """
    service_response = await RAGFlowService.stop_parse_documents_services(
        dataset_id,
        parse_params,
    )
    return parse_result(service_response)
# 删除文档
@ragflowController.post('/delete_file/{dataset_id}')
async def delete_file(
    dataset_id: str,
    delete_params: DeleteFileModel,
):
    """Delete documents from a dataset.

    :param dataset_id: Dataset id taken from the URL path.
    :param delete_params: Request body forwarded to the service unchanged.
    :return: The service result wrapped by ``parse_result``.
    """
    service_response = await RAGFlowService.delete_file_services(dataset_id, delete_params)
    return parse_result(service_response)
# 删除数据集
@ragflowController.post('/delete_datasets')
async def delete_datasets(
    delete_params: DeleteFileModel,
):
    """Delete datasets.

    :param delete_params: Request body; validated with the same model that
        document deletion uses and forwarded to the service unchanged.
    :return: The service result wrapped by ``parse_result``.
    """
    service_response = await RAGFlowService.delete_datasets_services(delete_params)
    return parse_result(service_response)
# 查看聊天助手列表
@ragflowController.post('/get_chat_assistant_list')
async def get_chat_assistant_list(
    query_params: RagflowListQueryModel,
):
    """Return the list of chat assistants matching the posted query.

    :param query_params: Request body forwarded to the service unchanged.
    :return: The service result wrapped by ``parse_result``.
    """
    service_response = await RAGFlowService.get_chat_assistant_list_services(query_params)
    return parse_result(service_response)
# 更新聊天助手
@ragflowController.post('/update_chat_assistant')
async def update_chat_assistant(
    update_params: UpdateChatAssistantModel,
):
    """Update a chat assistant.

    :param update_params: Request body forwarded to the service unchanged.
    :return: The service result wrapped by ``parse_result``.
    """
    service_response = await RAGFlowService.update_chat_assistant_services(update_params)
    return parse_result(service_response)
# 创建属于聊天助手的会话
@ragflowController.post('/create_session_with_chat')
async def create_session_with_chat(
    create_params: CreateSessionWithChatModel,
):
    """Create a session that belongs to a chat assistant.

    :param create_params: Request body forwarded to the service unchanged.
    :return: The service result wrapped by ``parse_result``.
    """
    service_response = await RAGFlowService.create_session_with_chat_services(create_params)
    return parse_result(service_response)
# 与聊天助手进行对话
async def _stream_chat_sse(chunks, chat_id, request_started_at):
    """Re-emit a cumulative RAGFlow answer stream as incremental SSE frames.

    Each upstream chunk is expected to look like ``{'data': {'answer': <text so
    far>}}``. Only the newly added suffix of ``answer`` is forwarded, so the
    client can append frames as they arrive. Frames are produced with
    ``format_sse``:

    * default event -- one payload per forwarded chunk;
    * ``end`` event -- ``{'status': 'completed'}`` once the upstream is exhausted;
    * ``error`` event -- ``{'message': ...}`` when the upstream reports a failure
      or iteration raises.

    :param chunks: Async iterable returned by the RAGFlow service.
    :param chat_id: Chat id, used only to label log lines.
    :param request_started_at: ``time.perf_counter()`` value taken when the
        request began, used for the total-duration log line.
    """
    last_answer = ''
    first_token_received = False
    stream_started_at = time.perf_counter()
    try:
        async for chunk in chunks:
            if not first_token_received:
                first_token_received = True
                latency = time.perf_counter() - stream_started_at
                logger.info(f"[RAG_PROFILE] Time to First Token ({chat_id}): {latency:.3f}s")

            is_envelope = isinstance(chunk, dict)
            # A truthy 'code' marks an upstream failure. Check it before the
            # empty-payload skip below: error envelopes usually carry no 'data',
            # so checking afterwards would drop them without a trace.
            upstream_failed = is_envelope and bool(chunk.get('code'))
            if upstream_failed:
                logger.error(f"RAGFlow Stream Error: {chunk}")

            payload = chunk.get('data') if is_envelope else chunk
            if not payload:
                if upstream_failed:
                    # Same message fallback chain as parse_result.
                    message = chunk.get('message') or chunk.get('msg') or '接口异常'
                    yield format_sse({'message': message}, event='error')
                continue

            body = payload if isinstance(payload, dict) else {'data': payload}
            answer = body.get('answer')
            if not isinstance(answer, str):
                # No textual answer to diff (e.g. the trailing non-dict marker):
                # forward the payload untouched.
                yield format_sse(body)
                continue

            if answer.startswith(last_answer):
                delta = answer[len(last_answer):]
                if not delta:
                    # Nothing new since the previous chunk.
                    continue
                body['answer'] = delta  # send only the new part
            # Otherwise the upstream restarted its text; forward it in full.
            last_answer = answer
            yield format_sse(body)

        yield format_sse({'status': 'completed'}, event='end')
    except Exception as exc:
        # Boundary of the streaming response: report the failure to the client
        # as an SSE frame instead of tearing the connection down.
        logger.exception(f'[RAG_PROFILE] ragflow流式对话异常: {exc}')
        yield format_sse({'message': str(exc)}, event='error')
    finally:
        total_time = time.perf_counter() - request_started_at
        logger.info(f'[RAG_PROFILE] Total Stream Duration ({chat_id}): {total_time:.3f}s')


@ragflowController.post('/converse_with_chat_assistant')
async def converse_with_chat_assistant(
    request: Request,
    converse_params: ConverseWithChatAssistantModel,
):
    """Converse with a chat assistant, either streamed (SSE) or as one response.

    Flow:

    1. ``SearchService.classify_intent`` routes the question; a ``'SEARCH'``
       intent is handed to ``SearchService.handle_search_chat`` entirely.
    2. For non-streaming requests a Redis entry (when ``app.state.redis`` is
       set) is consulted first and returned on a hit.
    3. Otherwise the RAGFlow service is called. Streaming requests get a
       ``text/event-stream`` response of incremental answer frames;
       non-streaming requests get the ``parse_result`` response, which is cached
       for 60 seconds when the service reported ``code == 0``.

    :param request: Incoming request; only ``request.app.state.redis`` is read.
    :param converse_params: Request body; ``chat_id``, ``question`` and
        ``stream`` are read here, the whole model is forwarded to the services.
    :return: A ``StreamingResponse`` when streaming, else a ``ResponseUtil`` response.
    """
    start_time = time.perf_counter()
    redis = getattr(request.app.state, 'redis', None)
    chat_id = converse_params.chat_id
    question = converse_params.question

    # Router strategy: decide between the search service and RAG.
    t0 = time.perf_counter()
    intent = await SearchService.classify_intent(question)
    logger.info(f"[RAG_PROFILE] Intent Router ({chat_id}): {intent} | Time: {time.perf_counter() - t0:.3f}s | Query: {question}")
    if intent == 'SEARCH':
        return await SearchService.handle_search_chat(converse_params, redis)

    cache_key = None
    if not converse_params.stream and redis:
        # NOTE(review): the key covers chat id + question only, so a cached
        # answer is shared by every session of this chat for its 60 s lifetime
        # -- confirm that is acceptable.
        cache_key = build_chat_cache_key(chat_id, question)
        cached = await redis.get(cache_key)
        if cached:
            logger.info(f'ragflow对话命中缓存: chat={chat_id}')
            # Pass the payload as data= so a cache hit has the same response
            # shape as the uncached path, which goes through parse_result.
            return ResponseUtil.success(data=json.loads(cached))

    t1 = time.perf_counter()
    # The service is awaited first; in streaming mode the awaited value is the
    # async iterable of chunks.
    result = await RAGFlowService.converse_with_chat_assistant_services(converse_params)
    logger.info(f"[RAG_PROFILE] RAG Init/Connect ({chat_id}): {time.perf_counter() - t1:.3f}s")

    if converse_params.stream:
        return StreamingResponse(
            _stream_chat_sse(result, chat_id, start_time),
            media_type='text/event-stream',
        )

    response = parse_result(result)
    if redis and cache_key and isinstance(result, dict) and result.get('code') == 0:
        await redis.set(cache_key, json.dumps(result.get('data'), ensure_ascii=False), ex=60)
    logger.info(f'[RAG_PROFILE] Total Sync Duration ({chat_id}): {time.perf_counter() - start_time:.3f}s')
    return response
# 获取用户权限
@ragflowController.get('/get_user_permission', dependencies=[Depends(CheckRoleInterfaceAuth('pad'))])
async def get_user_permission(current_user = Depends(LoginService.get_current_user)):
    """Return the permissions of the currently authenticated user.

    The route additionally runs the ``CheckRoleInterfaceAuth('pad')`` dependency.

    :param current_user: Object resolved by ``LoginService.get_current_user``;
        only its ``permissions`` attribute is read.
    :return: A success response whose ``data`` is ``current_user.permissions``.
    """
    user_auth_list = current_user.permissions
    # Route the diagnostic through the module logger instead of a bare print()
    # so permission data is levelled and not written straight to stdout.
    logger.debug(f'get_user_permission: {user_auth_list}')
    return ResponseUtil.success(data=user_auth_list)
def parse_result(result):
    """Convert a RAGFlow service result dict into a ``ResponseUtil`` response.

    A missing ``code`` is treated as ``0`` (success). Any other code yields an
    error response whose message is the first truthy value among ``message``,
    ``msg`` and a generic fallback. ``data`` is passed through in both cases.

    :param result: Mapping returned by a RAGFlow service call.
    :return: ``ResponseUtil.success(...)`` or ``ResponseUtil.error(...)``.
    """
    payload = result.get('data', None)
    if result.get('code', 0) == 0:
        return ResponseUtil.success(data=payload)

    error_text = result.get('message') or result.get('msg') or '接口异常'
    return ResponseUtil.error(msg=error_text, data=payload)
def build_chat_cache_key(chat_id: str, question: str) -> str:
    """Build the Redis key under which a chat answer is cached.

    The key is ``ragflow:chat:<chat_id>:<sha256 hex digest of the UTF-8 question>``.

    :param chat_id: Chat assistant id, embedded verbatim.
    :param question: Question text; only its SHA-256 digest appears in the key.
    :return: The cache key string.
    """
    hasher = hashlib.sha256()
    hasher.update(question.encode('utf-8'))
    return 'ragflow:chat:{}:{}'.format(chat_id, hasher.hexdigest())
def format_sse(data: dict, event: str | None = None) -> str:
payload = json.dumps(data, ensure_ascii=False)
prefix = f'event: {event}\n' if event else ''
return f'{prefix}data: {payload}\n\n'