376 lines
13 KiB
Python
376 lines
13 KiB
Python
# from datetime import datetime
|
||
import hashlib
|
||
import json
|
||
import time
|
||
from typing import List
|
||
from fastapi import APIRouter, Depends, Request, UploadFile, File
|
||
from fastapi.responses import StreamingResponse
|
||
# from pydantic_validation_decorator import ValidateFields
|
||
# from sqlalchemy.ext.asyncio import AsyncSession
|
||
# from config.enums import BusinessType
|
||
# from config.get_db import get_db
|
||
from module_admin.aspect.interface_auth import CheckRoleInterfaceAuth
|
||
# from module_admin.entity.vo.notice_vo import DeleteNoticeModel, NoticeModel, NoticePageQueryModel
|
||
# from module_admin.entity.vo.user_vo import CurrentUserModel
|
||
from module_admin.service.login_service import LoginService
|
||
from module_admin.service.ragflow_service import RAGFlowService
|
||
from module_admin.service.search_service import SearchService, SearchServiceError
|
||
|
||
from utils.log_util import logger
|
||
# from utils.page_util import PageResponseModel
|
||
from utils.response_util import ResponseUtil
|
||
from module_admin.entity.vo.ragflow_vo import RagflowListQueryModel, ListDocumentsQueryModel, UpdateFileModel, DeleteFileModel, CreateDatasetModel, DocumentIdsModel, UpdateChatAssistantModel,\
|
||
CreateSessionWithChatModel, ConverseWithChatAssistantModel
|
||
# from config.env import RAGFlowConfig
|
||
import asyncio
|
||
|
||
|
||
|
||
# Router for the RAGFlow endpoints, mounted under the /system/ragflow prefix.
# Every route registered on it declares LoginService.get_current_user as a
# dependency.
ragflowController = APIRouter(prefix="/system/ragflow", dependencies=[Depends(LoginService.get_current_user)])
|
||
|
||
|
||
|
||
# List datasets
# NOTE: `response_model=PageResponseModel` and the
# `CheckUserInterfaceAuth("system:ragflow:list")` dependency were commented out
# in the original decorator and remain disabled.
@ragflowController.post("/dataset_list")
async def get_system_ragflow_list(
    request: Request,
    rage_flow_dastset_query: RagflowListQueryModel,
):
    """Return the dataset list matching the posted query.

    Args:
        request: Incoming request (not used in the body).
        rage_flow_dastset_query: Request body with the list query parameters.

    Returns:
        The service reply wrapped by ``parse_result``.
    """
    # The service's first positional argument is passed as None, as before.
    service_reply = await RAGFlowService.get_ragflow_dataset_list_services(
        None,
        rage_flow_dastset_query,
    )
    return parse_result(service_reply)
|
||
|
||
# Create a dataset
@ragflowController.post('/create_dataset')
async def create_dataset(
    request: Request,
    create_dataset_params: CreateDatasetModel,
):
    """Create a dataset from the posted parameters.

    Args:
        request: Incoming request (not used in the body).
        create_dataset_params: Request body describing the dataset to create.

    Returns:
        The service reply wrapped by ``parse_result``.
    """
    return parse_result(
        await RAGFlowService.create_dataset_services(create_dataset_params)
    )
|
||
|
||
# Update a dataset
@ragflowController.post('/update_dataset/{dataset_id}')
async def update_dataset(
    request: Request,
    dataset_id: str,
    update_dataset_params: CreateDatasetModel,
):
    """Update the dataset identified by ``dataset_id``.

    Args:
        request: Incoming request (not used in the body).
        dataset_id: Path parameter naming the dataset to update.
        update_dataset_params: Request body with the new dataset settings.

    Returns:
        The service reply wrapped by ``parse_result``.
    """
    service_reply = await RAGFlowService.update_dataset_services(
        dataset_id,
        update_dataset_params,
    )
    return parse_result(service_reply)
|
||
|
||
|
||
# List the documents in a dataset
@ragflowController.get("/list_documents/{dataset_id}")
async def list_documents_by_dataset_id(
    request: Request,
    dataset_id: str,
    list_documents_query: ListDocumentsQueryModel = Depends(ListDocumentsQueryModel.as_query),
):
    """List the documents stored in a dataset.

    Args:
        request: Incoming request (not used in the body).
        dataset_id: Path parameter naming the dataset to inspect.
        list_documents_query: Query parameters, built through
            ``ListDocumentsQueryModel.as_query``.

    Returns:
        The service reply wrapped by ``parse_result``.
    """
    # Route the former debug print through the module logger so it honours the
    # configured log level instead of writing to stdout on every request.
    logger.debug(f'list_documents query ({dataset_id}): {list_documents_query}')

    # The service's first positional argument is passed as None, as before.
    result = await RAGFlowService.list_documents_services(None, dataset_id, list_documents_query)
    return parse_result(result)
|
||
|
||
|
||
# Upload files to a dataset
@ragflowController.post("/upload_file/{dataset_id}")
async def upload_file_dataset(
    dataset_id: str,
    files: List[UploadFile] = File(...),
):
    """Upload one or more files into a dataset.

    Args:
        dataset_id: Path parameter naming the target dataset.
        files: Uploaded files taken from the multipart form.

    Returns:
        The service reply wrapped by ``parse_result``.
    """
    # The service's first positional argument is passed as None, as before.
    upload_reply = await RAGFlowService.upload_file_dataset_services(
        None,
        dataset_id,
        files,
    )
    return parse_result(upload_reply)
|
||
|
||
# Update a document
@ragflowController.post("/update_file/{dataset_id}/{document_id}")
async def update_file_dataset(
    dataset_id: str,
    document_id: str,
    update_params: UpdateFileModel,
):
    """Update a document that belongs to a dataset.

    Args:
        dataset_id: Path parameter naming the dataset.
        document_id: Path parameter naming the document to update.
        update_params: Request body with the update settings.

    Returns:
        The service reply wrapped by ``parse_result``.
    """
    return parse_result(
        await RAGFlowService.update_file_dataset_services(
            dataset_id,
            document_id,
            update_params,
        )
    )
|
||
|
||
# Start parsing documents
@ragflowController.post('/parse_documents/{dataset_id}')
async def parse_documents(
    dataset_id: str,
    parse_params: DocumentIdsModel,
):
    """Start parsing the given documents of a dataset.

    Args:
        dataset_id: Path parameter naming the dataset.
        parse_params: Request body identifying the documents.

    Returns:
        The service reply wrapped by ``parse_result``.
    """
    service_reply = await RAGFlowService.parse_documents_services(
        dataset_id,
        parse_params,
    )
    return parse_result(service_reply)
|
||
|
||
# Stop parsing documents
@ragflowController.post('/stop_parse_documents/{dataset_id}')
async def stop_parse_documents(
    dataset_id: str,
    parse_params: DocumentIdsModel,
):
    """Stop parsing the given documents of a dataset.

    Args:
        dataset_id: Path parameter naming the dataset.
        parse_params: Request body identifying the documents.

    Returns:
        The service reply wrapped by ``parse_result``.
    """
    return parse_result(
        await RAGFlowService.stop_parse_documents_services(dataset_id, parse_params)
    )
|
||
|
||
# Delete documents
@ragflowController.post('/delete_file/{dataset_id}')
async def delete_file(
    dataset_id: str,
    delete_params: DeleteFileModel,
):
    """Delete files from a dataset.

    Args:
        dataset_id: Path parameter naming the dataset.
        delete_params: Request body describing what to delete.

    Returns:
        The service reply wrapped by ``parse_result``.
    """
    deletion_reply = await RAGFlowService.delete_file_services(
        dataset_id,
        delete_params,
    )
    return parse_result(deletion_reply)
|
||
|
||
|
||
# Delete datasets
@ragflowController.post('/delete_datasets')
async def delete_datasets(
    delete_params: DeleteFileModel,
):
    """Delete datasets.

    Args:
        delete_params: Request body describing what to delete (the endpoint
            reuses ``DeleteFileModel``).

    Returns:
        The service reply wrapped by ``parse_result``.
    """
    return parse_result(
        await RAGFlowService.delete_datasets_services(delete_params)
    )
|
||
|
||
# List chat assistants
@ragflowController.post('/get_chat_assistant_list')
async def get_chat_assistant_list(
    query_params: RagflowListQueryModel,
):
    """Return the list of chat assistants matching the posted query.

    Args:
        query_params: Request body with the list query parameters.

    Returns:
        The service reply wrapped by ``parse_result``.
    """
    assistants_reply = await RAGFlowService.get_chat_assistant_list_services(
        query_params
    )
    return parse_result(assistants_reply)
|
||
|
||
# pass
|
||
|
||
# Update a chat assistant
@ragflowController.post('/update_chat_assistant')
async def update_chat_assistant(
    update_params: UpdateChatAssistantModel,
):
    """Update a chat assistant.

    Args:
        update_params: Request body with the assistant update settings.

    Returns:
        The service reply wrapped by ``parse_result``.
    """
    return parse_result(
        await RAGFlowService.update_chat_assistant_services(update_params)
    )
|
||
|
||
# Create a session that belongs to a chat assistant
@ragflowController.post('/create_session_with_chat')
async def create_session_with_chat(
    create_params: CreateSessionWithChatModel,
):
    """Create a session that belongs to a chat assistant.

    Args:
        create_params: Request body describing the session to create.

    Returns:
        The service reply wrapped by ``parse_result``.
    """
    session_reply = await RAGFlowService.create_session_with_chat_services(
        create_params
    )
    return parse_result(session_reply)
|
||
|
||
# Converse with a chat assistant
@ragflowController.post('/converse_with_chat_assistant')
async def converse_with_chat_assistant(
    request: Request,
    converse_params: ConverseWithChatAssistantModel,
):
    """Answer a chat question through web search or the native RAGFlow chat.

    Flow:
        1. ``SearchService.classify_intent`` routes the question; a
           ``'SEARCH'`` intent is delegated to
           ``SearchService.handle_search_chat``.
        2. For non-streaming requests with a Redis client on
           ``request.app.state.redis``, a cached answer is returned when
           present.
        3. Otherwise ``RAGFlowService.converse_with_chat_assistant_services``
           is called. Streaming requests get a ``text/event-stream`` response
           whose ``answer`` fields carry only the newly generated text;
           non-streaming successes (``code == 0``) are cached for 60 seconds.

    Args:
        request: Incoming request; only ``request.app.state.redis`` is read.
        converse_params: Request body; ``chat_id``, ``question`` and
            ``stream`` are read here.

    Returns:
        A ``StreamingResponse`` when ``converse_params.stream`` is truthy,
        otherwise a ``ResponseUtil`` response (or whatever
        ``SearchService.handle_search_chat`` returns for search intents).
    """
    start_time = time.perf_counter()
    redis = getattr(request.app.state, 'redis', None)
    cache_key = None

    # Router strategy: decide whether the question goes to search or to RAG.
    # (Original note: 'glm-4-flash' is used to classify the intent.)
    t0 = time.perf_counter()
    intent = await SearchService.classify_intent(converse_params.question)
    logger.info(f"[RAG_PROFILE] Intent Router ({converse_params.chat_id}): {intent} | Time: {time.perf_counter() - t0:.3f}s | Query: {converse_params.question}")

    if intent == 'SEARCH':
        return await SearchService.handle_search_chat(converse_params, redis)

    # Only non-streaming answers are cached.
    if not converse_params.stream and redis:
        cache_key = build_chat_cache_key(converse_params.chat_id, converse_params.question)
        cached = await redis.get(cache_key)
        if cached:
            logger.info(f'ragflow对话命中缓存: chat={converse_params.chat_id}')
            return ResponseUtil.success(json.loads(cached))

    # Native RAGFlow service (original note: the OpenAI-compatible endpoint
    # failed with an auth error, so the code was reverted to this path).
    t1 = time.perf_counter()
    result = RAGFlowService.converse_with_chat_assistant_services(converse_params)
    logger.info(f"[RAG_PROFILE] RAG Init/Connect ({converse_params.chat_id}): {time.perf_counter() - t1:.3f}s")

    if converse_params.stream:
        async def stream_response():
            """Yield SSE frames, turning cumulative answers into deltas."""
            last_answer = ""
            first_token_received = False
            start_stream_time = time.perf_counter()

            # `result` is a synchronous iterable. Calling next() on it inside
            # the event loop would block every other request while waiting for
            # the upstream chunk, so each chunk is fetched in a worker thread.
            # A private sentinel marks exhaustion because StopIteration cannot
            # be raised through an awaited future.
            exhausted = object()
            chunks = iter(result)

            try:
                while True:
                    chunk = await asyncio.to_thread(next, chunks, exhausted)
                    if chunk is exhausted:
                        break

                    if not first_token_received:
                        first_token_received = True
                        latency = time.perf_counter() - start_stream_time
                        logger.info(f"[RAG_PROFILE] Time to First Token ({converse_params.chat_id}): {latency:.3f}s")

                    # Log upstream errors before any filtering, so error chunks
                    # that carry no 'data' are not dropped silently.
                    if isinstance(chunk, dict) and chunk.get('code'):
                        logger.error(f"RAGFlow Stream Error: {chunk}")

                    # Expected chunk structure (per the original note):
                    # {'data': {'answer': 'Cumulative text...'}}
                    payload = chunk.get('data') if isinstance(chunk, dict) else chunk
                    if not payload:
                        continue

                    body = payload if isinstance(payload, dict) else {'data': payload}

                    if 'answer' not in body:
                        yield format_sse(body)
                        continue

                    current_answer = body['answer']
                    if current_answer.startswith(last_answer):
                        # Cumulative text: forward only the newly added part.
                        delta = current_answer[len(last_answer):]
                        if delta:
                            body['answer'] = delta
                            last_answer = current_answer
                            yield format_sse(body)
                    else:
                        # The answer no longer extends the previous one
                        # (context reset): forward it as-is and restart.
                        last_answer = current_answer
                        yield format_sse(body)

                yield format_sse({'status': 'completed'}, event='end')
            except Exception as exc:
                # Top-level boundary of the stream: report the failure to the
                # client as an SSE 'error' event instead of breaking the
                # connection.
                logger.exception(f'[RAG_PROFILE] ragflow流式对话异常: {exc}')
                yield format_sse({'message': str(exc)}, event='error')
            finally:
                total_time = time.perf_counter() - start_time
                logger.info(f'[RAG_PROFILE] Total Stream Duration ({converse_params.chat_id}): {total_time:.3f}s')

        return StreamingResponse(stream_response(), media_type='text/event-stream')

    response = parse_result(result)
    if redis and cache_key and isinstance(result, dict) and result.get('code') == 0:
        await redis.set(cache_key, json.dumps(result.get('data'), ensure_ascii=False), ex=60)
    logger.info(f'[RAG_PROFILE] Total Sync Duration ({converse_params.chat_id}): {time.perf_counter() - start_time:.3f}s')
    return response
|
||
|
||
|
||
|
||
# return parse_result(result)
|
||
|
||
# Get the current user's permissions
@ragflowController.get('/get_user_permission', dependencies=[Depends(CheckRoleInterfaceAuth('pad'))])
async def get_user_permission(current_user=Depends(LoginService.get_current_user)):
    """Return the permission list of the current user.

    Args:
        current_user: Object resolved by ``LoginService.get_current_user``;
            only its ``permissions`` attribute is read.

    Returns:
        ``ResponseUtil.success`` carrying ``current_user.permissions``.
    """
    user_auth_list = current_user.permissions
    # Route the former debug print through the module logger so the permission
    # list is not written to stdout on every request.
    logger.debug(f'get_user_permission permissions: {user_auth_list}')
    return ResponseUtil.success(data=user_auth_list)
|
||
|
||
def parse_result(result):
    """Convert a service reply mapping into a ``ResponseUtil`` response.

    A reply whose ``code`` is missing or equal to 0 is a success; any other
    code is an error whose message comes from ``message``, then ``msg``, then
    a default text.

    Args:
        result: Mapping exposing ``get`` with optional ``code``, ``message``,
            ``msg`` and ``data`` keys.

    Returns:
        ``ResponseUtil.success`` or ``ResponseUtil.error`` carrying the
        reply's ``data``.
    """
    payload = result.get('data', None)
    if result.get('code', 0) == 0:
        return ResponseUtil.success(data=payload)

    error_text = result.get('message') or result.get('msg') or '接口异常'
    return ResponseUtil.error(msg=error_text, data=payload)
|
||
|
||
|
||
def build_chat_cache_key(chat_id: str, question: str) -> str:
    """Build the Redis cache key for a chat question.

    Args:
        chat_id: Identifier placed verbatim in the key.
        question: Question text; its UTF-8 SHA-256 hex digest ends the key.

    Returns:
        A key of the form ``ragflow:chat:<chat_id>:<sha256 hex digest>``.
    """
    hasher = hashlib.sha256()
    hasher.update(question.encode('utf-8'))
    question_hash = hasher.hexdigest()
    return f'ragflow:chat:{chat_id}:{question_hash}'
|
||
|
||
|
||
def format_sse(data: dict, event: str | None = None) -> str:
|
||
payload = json.dumps(data, ensure_ascii=False)
|
||
prefix = f'event: {event}\n' if event else ''
|
||
return f'{prefix}data: {payload}\n\n'
|
||
|
||
|
||
async def aiter_sync_generator(sync_gen):
    """Adapt a synchronous iterator into an asynchronous generator.

    Each item is fetched by calling ``next`` in the running loop's default
    executor, so a slow producer does not block the event loop while the
    consumer awaits the next item.

    Args:
        sync_gen: A synchronous iterator or iterable (typically a generator).

    Yields:
        Every item produced by ``sync_gen``, in order, including ``None``.
    """
    loop = asyncio.get_running_loop()
    iterator = iter(sync_gen)
    # A private sentinel marks exhaustion. Using None for that purpose would
    # end the stream early whenever the producer legitimately yields None.
    # Because next() is given a default, StopIteration is never raised here.
    exhausted = object()
    while True:
        value = await loop.run_in_executor(None, next, iterator, exhausted)
        if value is exhausted:
            break
        yield value
|
||
import asyncio
|
||
import hashlib
|
||
import json
|
||
import time
|
||
from typing import AsyncIterator, Union
|
||
|
||
import aiofiles
|
||
import redis
|
||
from fastapi import Depends, Request, Response
|
||
from fastapi.responses import StreamingResponse
|
||
from fastapi_controller.decorators import Controller
|
||
from fastapi_controller.dependencies import LoginService
|
||
from fastapi_controller.response import ResponseUtil
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from module_admin.service.ragflow_service import RAGFlowService
|
||
from module_admin.service.search_service import SearchService
|
||
from module_admin.service.login_service import CheckRoleInterfaceAuth
|
||
from module_admin.model.ragflow_models import (
|
||
ConverseWithChatAssistantModel,
|
||
CreateSessionWithChatModel,
|
||
DeleteFileModel,
|
||
DocumentIdsModel,
|
||
RagflowListQueryModel,
|
||
UpdateChatAssistantModel,
|
||
)
|
||
|
||
import logger as logger_module |