205 lines
6.2 KiB
Python
205 lines
6.2 KiB
Python
"""
|
||
文件管理API路由
|
||
提供CAD文件列表和下载功能
|
||
"""
|
||
from fastapi import APIRouter, HTTPException
|
||
from fastapi.responses import FileResponse, StreamingResponse
|
||
from typing import List
|
||
from pathlib import Path
|
||
import os
|
||
import zipfile
|
||
import io
|
||
from app.config import settings, software_config
|
||
|
||
|
||
router = APIRouter()
|
||
|
||
|
||
def get_cad_extensions() -> dict:
|
||
"""获取CAD文件扩展名配置"""
|
||
return software_config.get_file_extensions()
|
||
|
||
|
||
def get_cad_files_path() -> Path:
|
||
"""获取CAD文件存储路径"""
|
||
cad_path = software_config.get_cad_files_path()
|
||
path = Path(cad_path)
|
||
|
||
if not path.exists():
|
||
raise HTTPException(status_code=404, detail=f"CAD文件路径不存在: {cad_path}")
|
||
|
||
return path
|
||
|
||
|
||
def is_cad_file(filename: str) -> bool:
|
||
"""判断文件是否为CAD文件(包括版本号后缀)"""
|
||
filename_lower = filename.lower()
|
||
|
||
# 获取配置的扩展名
|
||
cad_extensions = get_cad_extensions()
|
||
|
||
# 检查所有CAD扩展名
|
||
for software, extensions in cad_extensions.items():
|
||
for ext in extensions:
|
||
# 检查标准扩展名
|
||
if filename_lower.endswith(ext):
|
||
return True
|
||
# 检查带版本号的扩展名(如 .prt.1, .prt.2 等)
|
||
if ext in filename_lower and '.' in filename_lower.split(ext)[-1]:
|
||
# 验证版本号部分是否为数字
|
||
version_part = filename_lower.split(ext)[-1]
|
||
if version_part.startswith('.') and version_part[1:].isdigit():
|
||
return True
|
||
|
||
return False
|
||
|
||
|
||
def scan_cad_files(base_path: Path) -> List[dict]:
|
||
"""扫描目录下的所有CAD文件"""
|
||
cad_files = []
|
||
|
||
try:
|
||
for root, dirs, files in os.walk(base_path):
|
||
for file in files:
|
||
if is_cad_file(file):
|
||
file_path = Path(root) / file
|
||
relative_path = file_path.relative_to(base_path)
|
||
|
||
# 获取文件信息
|
||
stat = file_path.stat()
|
||
|
||
cad_files.append({
|
||
'filename': file,
|
||
'relative_path': str(relative_path).replace('\\', '/'),
|
||
'absolute_path': str(file_path),
|
||
'size': stat.st_size,
|
||
'modified_time': stat.st_mtime,
|
||
'extension': get_file_extension(file)
|
||
})
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=f"扫描文件失败: {str(e)}")
|
||
|
||
return cad_files
|
||
|
||
|
||
def get_file_extension(filename: str) -> str:
|
||
"""获取文件扩展名(包括版本号)"""
|
||
filename_lower = filename.lower()
|
||
|
||
# 获取配置的扩展名
|
||
cad_extensions = get_cad_extensions()
|
||
|
||
# 检查是否有版本号
|
||
for software, extensions in cad_extensions.items():
|
||
for ext in extensions:
|
||
if ext in filename_lower:
|
||
idx = filename_lower.find(ext)
|
||
return filename[idx:]
|
||
|
||
# 如果没有匹配,返回标准扩展名
|
||
return Path(filename).suffix
|
||
|
||
|
||
@router.get("/files/list")
|
||
async def list_cad_files():
|
||
"""
|
||
获取CAD文件列表
|
||
返回所有Creo、PDMS、Revit格式的文件
|
||
"""
|
||
base_path = get_cad_files_path()
|
||
files = scan_cad_files(base_path)
|
||
|
||
return {
|
||
'success': True,
|
||
'base_path': str(base_path),
|
||
'total_count': len(files),
|
||
'files': files
|
||
}
|
||
|
||
|
||
@router.get("/files/download/{file_path:path}")
|
||
async def download_file(file_path: str):
|
||
"""
|
||
下载单个文件
|
||
|
||
Args:
|
||
file_path: 文件的相对路径
|
||
"""
|
||
base_path = get_cad_files_path()
|
||
full_path = base_path / file_path
|
||
|
||
# 安全检查:确保文件在允许的目录内
|
||
try:
|
||
full_path = full_path.resolve()
|
||
base_path = base_path.resolve()
|
||
|
||
if not str(full_path).startswith(str(base_path)):
|
||
raise HTTPException(status_code=403, detail="访问被拒绝")
|
||
except Exception:
|
||
raise HTTPException(status_code=400, detail="无效的文件路径")
|
||
|
||
# 检查文件是否存在
|
||
if not full_path.exists():
|
||
raise HTTPException(status_code=404, detail="文件不存在")
|
||
|
||
if not full_path.is_file():
|
||
raise HTTPException(status_code=400, detail="不是有效的文件")
|
||
|
||
# 检查是否为CAD文件
|
||
if not is_cad_file(full_path.name):
|
||
raise HTTPException(status_code=403, detail="只能下载CAD文件")
|
||
|
||
return FileResponse(
|
||
path=str(full_path),
|
||
filename=full_path.name,
|
||
media_type='application/octet-stream'
|
||
)
|
||
|
||
|
||
@router.post("/files/download/batch")
|
||
async def download_batch_files(file_paths: List[str]):
|
||
"""
|
||
批量下载文件(打包为ZIP)
|
||
|
||
Args:
|
||
file_paths: 文件相对路径列表
|
||
"""
|
||
if not file_paths:
|
||
raise HTTPException(status_code=400, detail="文件列表不能为空")
|
||
|
||
base_path = get_cad_files_path()
|
||
|
||
# 创建内存中的ZIP文件
|
||
zip_buffer = io.BytesIO()
|
||
|
||
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
|
||
for file_path in file_paths:
|
||
full_path = base_path / file_path
|
||
|
||
# 安全检查
|
||
try:
|
||
full_path = full_path.resolve()
|
||
base_path_resolved = base_path.resolve()
|
||
|
||
if not str(full_path).startswith(str(base_path_resolved)):
|
||
continue # 跳过不安全的路径
|
||
except Exception:
|
||
continue
|
||
|
||
# 检查文件是否存在且为CAD文件
|
||
if full_path.exists() and full_path.is_file() and is_cad_file(full_path.name):
|
||
# 使用相对路径作为ZIP内的路径
|
||
arcname = file_path.replace('\\', '/')
|
||
zip_file.write(str(full_path), arcname=arcname)
|
||
|
||
# 重置缓冲区位置
|
||
zip_buffer.seek(0)
|
||
|
||
return StreamingResponse(
|
||
zip_buffer,
|
||
media_type='application/zip',
|
||
headers={
|
||
'Content-Disposition': 'attachment; filename="cad_files.zip"'
|
||
}
|
||
)
|