kangda-robot-backend/ruoyi-fastapi-backend/module_admin/controller/ragflow_controller.py

344 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# from datetime import datetime
import hashlib
import json
import time
from typing import List
from fastapi import APIRouter, Depends, Request, UploadFile, File
from fastapi.responses import StreamingResponse
# from pydantic_validation_decorator import ValidateFields
# from sqlalchemy.ext.asyncio import AsyncSession
# from config.enums import BusinessType
# from config.get_db import get_db
from module_admin.aspect.interface_auth import CheckRoleInterfaceAuth
# from module_admin.entity.vo.notice_vo import DeleteNoticeModel, NoticeModel, NoticePageQueryModel
# from module_admin.entity.vo.user_vo import CurrentUserModel
from module_admin.service.login_service import LoginService
from module_admin.service.ragflow_service import RAGFlowService
from module_admin.service.search_service import SearchService, SearchServiceError
from utils.log_util import logger
# from utils.page_util import PageResponseModel
from utils.response_util import ResponseUtil
from module_admin.entity.vo.ragflow_vo import RagflowListQueryModel, ListDocumentsQueryModel, UpdateFileModel, DeleteFileModel, CreateDatasetModel, DocumentIdsModel, UpdateChatAssistantModel,\
CreateSessionWithChatModel, ConverseWithChatAssistantModel
# from config.env import RAGFlowConfig
# Router for the RAGFlow proxy endpoints, mounted under /system/ragflow.
# Every route on it runs LoginService.get_current_user as a dependency
# (presumably rejecting unauthenticated requests — confirm in LoginService).
ragflowController = APIRouter(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)])
# List datasets.
# NOTE: the original had a per-permission check
# (CheckUserInterfaceAuth("system:ragflow:list")) commented out, so only the
# router-level login dependency applies to this route.
@ragflowController.post("/dataset_list")
async def get_system_ragflow_list(
    request: Request,
    rage_flow_dastset_query: RagflowListQueryModel,
):
    """List RAGFlow datasets.

    Forwards the query to ``RAGFlowService.get_ragflow_dataset_list_services``
    (first argument ``None``, presumably standing in for the unused DB
    session) and converts its result with ``parse_result``.
    """
    dataset_page = await RAGFlowService.get_ragflow_dataset_list_services(
        None, rage_flow_dastset_query
    )
    return parse_result(dataset_page)
# Create a dataset.
@ragflowController.post('/create_dataset')
async def create_dataset(
    request: Request,
    create_dataset_params: CreateDatasetModel,
):
    """Create a RAGFlow dataset from ``create_dataset_params``.

    Delegates to ``RAGFlowService.create_dataset_services`` and converts the
    outcome with ``parse_result``.
    """
    return parse_result(await RAGFlowService.create_dataset_services(create_dataset_params))
# Update a dataset.
@ragflowController.post('/update_dataset/{dataset_id}')
async def update_dataset(
    request: Request,
    dataset_id: str,
    update_dataset_params: CreateDatasetModel,
):
    """Update the dataset identified by ``dataset_id``.

    The request body reuses ``CreateDatasetModel``; the work is delegated to
    ``RAGFlowService.update_dataset_services`` and the result converted with
    ``parse_result``.
    """
    service_result = await RAGFlowService.update_dataset_services(
        dataset_id, update_dataset_params
    )
    return parse_result(service_result)
# List the documents in a dataset.
@ragflowController.get("/list_documents/{dataset_id}")
async def list_documents_by_dataset_id(
    request: Request,
    dataset_id: str,
    list_documents_query: ListDocumentsQueryModel = Depends(ListDocumentsQueryModel.as_query),
):
    """List the documents stored in dataset ``dataset_id``.

    Query-string parameters are parsed into ``ListDocumentsQueryModel`` via its
    ``as_query`` dependency and forwarded to
    ``RAGFlowService.list_documents_services`` (first argument ``None``,
    presumably a placeholder for the unused DB session). The result is
    converted with ``parse_result``.
    """
    # Was a bare print() to stdout; route the debug output through the app logger.
    logger.debug(f'list_documents query for dataset {dataset_id}: {list_documents_query}')
    result = await RAGFlowService.list_documents_services(None, dataset_id, list_documents_query)
    return parse_result(result)
# Upload files into a dataset.
@ragflowController.post("/upload_file/{dataset_id}")
async def upload_file_dataset(
    dataset_id: str,
    files: List[UploadFile] = File(...),
):
    """Upload one or more multipart files into dataset ``dataset_id``.

    The files are passed unchanged to
    ``RAGFlowService.upload_file_dataset_services`` (first argument ``None``,
    presumably a placeholder for the unused DB session).
    """
    upload_result = await RAGFlowService.upload_file_dataset_services(None, dataset_id, files)
    return parse_result(upload_result)
# Update a document.
@ragflowController.post("/update_file/{dataset_id}/{document_id}")
async def update_file_dataset(
    dataset_id: str,
    document_id: str,
    update_params: UpdateFileModel,
):
    """Update document ``document_id`` inside dataset ``dataset_id``.

    The new settings come from ``update_params`` and are applied by
    ``RAGFlowService.update_file_dataset_services``.
    """
    return parse_result(
        await RAGFlowService.update_file_dataset_services(dataset_id, document_id, update_params)
    )
# Start parsing documents.
@ragflowController.post('/parse_documents/{dataset_id}')
async def parse_documents(
    dataset_id: str,
    parse_params: DocumentIdsModel,
):
    """Start parsing the documents listed in ``parse_params`` within ``dataset_id``."""
    parse_outcome = await RAGFlowService.parse_documents_services(dataset_id, parse_params)
    return parse_result(parse_outcome)
# Stop parsing documents.
@ragflowController.post('/stop_parse_documents/{dataset_id}')
async def stop_parse_documents(
    dataset_id: str,
    parse_params: DocumentIdsModel,
):
    """Stop parsing the documents listed in ``parse_params`` within ``dataset_id``."""
    return parse_result(
        await RAGFlowService.stop_parse_documents_services(dataset_id, parse_params)
    )
# Delete documents.
@ragflowController.post('/delete_file/{dataset_id}')
async def delete_file(
    dataset_id: str,
    delete_params: DeleteFileModel,
):
    """Delete the documents described by ``delete_params`` from ``dataset_id``."""
    deletion = await RAGFlowService.delete_file_services(dataset_id, delete_params)
    return parse_result(deletion)
# Delete datasets.
@ragflowController.post('/delete_datasets')
async def delete_datasets(
    delete_params: DeleteFileModel,
):
    """Delete the datasets described by ``delete_params``.

    Note: the body model is ``DeleteFileModel``, shared with document deletion.
    """
    return parse_result(await RAGFlowService.delete_datasets_services(delete_params))
# List chat assistants.
@ragflowController.post('/get_chat_assistant_list')
async def get_chat_assistant_list(
    query_params: RagflowListQueryModel,
):
    """List RAGFlow chat assistants matching ``query_params``."""
    assistants = await RAGFlowService.get_chat_assistant_list_services(query_params)
    return parse_result(assistants)
# Update a chat assistant.
@ragflowController.post('/update_chat_assistant')
async def update_chat_assistant(
    update_params: UpdateChatAssistantModel,
):
    """Update a chat assistant using the settings in ``update_params``."""
    return parse_result(
        await RAGFlowService.update_chat_assistant_services(update_params)
    )
# Create a session belonging to a chat assistant.
@ragflowController.post('/create_session_with_chat')
async def create_session_with_chat(
    create_params: CreateSessionWithChatModel,
):
    """Create a new conversation session for the chat assistant in ``create_params``."""
    session = await RAGFlowService.create_session_with_chat_services(create_params)
    return parse_result(session)
def _build_search_prompt(question: str, context: str) -> str:
    """Build the prompt asking the RAGFlow LLM to answer from web-search results.

    Args:
        question: The user's original question.
        context: Search context text (``context_for_llm`` from SearchService).

    Returns:
        The question wrapped in answering instructions plus the search context.
    """
    # Built line by line so the prompt text never depends on source indentation.
    return (
        '请基于以下搜索结果,简洁准确地回答用户的问题。\n'
        f'用户问题:{question}\n'
        '搜索结果:\n'
        f'{context}\n'
        '要求:\n'
        '1. 直接回答问题,不要说"根据搜索结果"等前缀\n'
        '2. 答案要简洁准确,突出重点\n'
        '3. 如果是天气查询,只返回当前天气状况和温度\n'
        '4. 如果是新闻查询,只返回最重要的新闻摘要\n'
        '5. 用自然的语气回答,像人类助手一样'
    )


async def _relay_sse_stream(result, start_time: float, error_message: str, elapsed_message: str):
    """Re-emit an upstream RAGFlow chunk stream as Server-Sent Events.

    Each chunk's ``data`` field (or the chunk itself when it is not a dict) is
    forwarded as one SSE frame; falsy payloads are skipped and non-dict payloads
    are wrapped as ``{'data': payload}``. An ``end`` event follows on success,
    or an ``error`` event if iterating the upstream stream raises.

    Args:
        result: Async iterable returned by
            ``RAGFlowService.converse_with_chat_assistant_services`` when streaming.
        start_time: ``time.perf_counter()`` value taken when the request began.
        error_message: Log prefix used when the upstream stream raises.
        elapsed_message: Log prefix for the elapsed-time line logged at the end.
    """
    # f-strings are used for log messages so they render the same whether
    # utils.log_util.logger is stdlib logging or loguru (loguru ignores %-args).
    try:
        async for chunk in result:
            payload = chunk.get('data') if isinstance(chunk, dict) else chunk
            if not payload:
                continue
            body = payload if isinstance(payload, dict) else {'data': payload}
            yield format_sse(body)
        yield format_sse({'status': 'completed'}, event='end')
    except Exception as exc:
        # Headers are already sent, so the failure is reported in-band as an SSE event.
        logger.exception(f'{error_message}: {exc}')
        yield format_sse({'message': str(exc)}, event='error')
    finally:
        logger.info(f'{elapsed_message} {time.perf_counter() - start_time:.3f}s')


# Converse with a chat assistant.
@ragflowController.post('/converse_with_chat_assistant')
async def converse_with_chat_assistant(
    request: Request,
    converse_params: ConverseWithChatAssistantModel,
):
    """Send a question to a RAGFlow chat assistant, optionally augmented by web search.

    Flow:
    1. If ``SearchService.should_handle`` accepts the question, fetch search
       results and ask RAGFlow to answer from them. A ``SearchServiceError``
       falls through to step 2.
    2. Otherwise, for non-streaming requests with Redis available, try the cache
       keyed by chat id + question hash; on a miss call RAGFlow and cache
       successful (``code == 0``) answers for 60 seconds. Cache errors are
       logged and ignored.

    Returns:
        A ``text/event-stream`` ``StreamingResponse`` when
        ``converse_params.stream`` is true; otherwise the response built by
        ``parse_result`` (or ``ResponseUtil.success`` on a cache hit).
    """
    start_time = time.perf_counter()
    # Redis is optional: without app.state.redis the cache is skipped entirely.
    redis = getattr(request.app.state, 'redis', None)

    if SearchService.should_handle(converse_params.question):
        try:
            search_payload = await SearchService.get_search_answer(converse_params.question, redis=redis)
        except SearchServiceError as exc:
            # Search is best-effort: degrade to the plain RAGFlow path below.
            logger.warning(f'搜索服务失败降级到RAGFlow: {exc}')
        else:
            enhanced_question = _build_search_prompt(
                converse_params.question, search_payload.get('context_for_llm', '')
            )
            # model_copy keeps every other request field (e.g. a session id, if the
            # model has one) instead of rebuilding from three hand-picked fields.
            enhanced_params = converse_params.model_copy(update={'question': enhanced_question})
            result = await RAGFlowService.converse_with_chat_assistant_services(enhanced_params)
            if converse_params.stream:
                return StreamingResponse(
                    _relay_sse_stream(result, start_time, '搜索+LLM流式处理异常', '搜索+LLM处理耗时'),
                    media_type='text/event-stream',
                )
            response = parse_result(result)
            logger.info(f'搜索+LLM处理耗时 {time.perf_counter() - start_time:.3f}s')
            return response

    cache_key = None
    if not converse_params.stream and redis:
        cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
        try:
            cached = await redis.get(cache_key)
        except Exception as exc:
            # A cache outage must not break the chat; just skip the cache.
            logger.warning(f'ragflow对话缓存读取失败: {exc}')
            cached = None
        if cached:
            logger.info(f'ragflow对话命中缓存: chat={converse_params.chat_id}')
            return ResponseUtil.success(json.loads(cached))

    result = await RAGFlowService.converse_with_chat_assistant_services(converse_params)
    if converse_params.stream:
        return StreamingResponse(
            _relay_sse_stream(result, start_time, 'ragflow流式对话异常', 'ragflow流式对话耗时'),
            media_type='text/event-stream',
        )

    response = parse_result(result)
    if redis and cache_key and isinstance(result, dict) and result.get('code') == 0:
        try:
            await redis.set(cache_key, json.dumps(result.get('data'), ensure_ascii=False), ex=60)
        except Exception as exc:
            # Failing to cache is not fatal; the answer is still returned.
            logger.warning(f'ragflow对话缓存写入失败: {exc}')
    logger.info(f'ragflow对话耗时 {time.perf_counter() - start_time:.3f}s')
    return response
# Get the current user's permissions.
@ragflowController.get('/get_user_permission', dependencies=[Depends(CheckRoleInterfaceAuth('pad'))])
async def get_user_permission(current_user = Depends(LoginService.get_current_user)):
    """Return the ``permissions`` list of the logged-in user.

    The route additionally depends on ``CheckRoleInterfaceAuth('pad')``
    (presumably restricting it to users holding the 'pad' role — confirm in
    interface_auth).
    """
    user_auth_list = current_user.permissions
    # Was a bare print() to stdout; route the debug output through the app logger.
    logger.debug(f'get_user_permission: {user_auth_list}')
    return ResponseUtil.success(data=user_auth_list)
def parse_result(result):
    """Convert a RAGFlow service result dict into the project's standard response.

    A missing ``code`` is treated as success (``0``).

    Args:
        result: Mapping with ``code`` and ``data`` keys, and on failure possibly
            ``message`` (assumes the service forwards RAGFlow's JSON body —
            confirm in RAGFlowService).

    Returns:
        ``ResponseUtil.success(data)`` when ``code == 0``. Otherwise
        ``ResponseUtil.error(...)`` with ``data``, or with ``message`` when
        ``data`` is missing/None.
    """
    code = result.get('code', 0)
    data = result.get('data')
    if code != 0:
        # RAGFlow's HTTP API reports failures as {"code": ..., "message": ...}
        # without "data"; fall back to "message" so the reason isn't lost.
        return ResponseUtil.error(data if data is not None else result.get('message'))
    return ResponseUtil.success(data)
def build_chat_cache_key(chat_id: str, question: str) -> str:
    """Return the Redis key used to cache a non-streaming chat answer.

    The question is hashed (SHA-256 of its UTF-8 bytes) so arbitrary-length
    text yields a fixed-size key: ``ragflow:chat:<chat_id>:<hex digest>``.
    """
    question_hash = hashlib.sha256(question.encode('utf-8')).hexdigest()
    return 'ragflow:chat:{}:{}'.format(chat_id, question_hash)
def format_sse(data: dict, event: str | None = None) -> str:
payload = json.dumps(data, ensure_ascii=False)
prefix = f'event: {event}\n' if event else ''
return f'{prefix}data: {payload}\n\n'