100 lines
2.9 KiB
Python
100 lines
2.9 KiB
Python
import os
|
||
|
||
class Config:
    """Application configuration.

    Every setting is a class attribute, so values can be read from the class
    itself (``Config.API_PREFIX``) or from an instance. The settings backed by
    ``os.getenv`` are read once, when this class body executes (import time).
    """

    # Database (SQLite). An empty value means "use the default path"; per the
    # original note that default is data/equipment_cost.db, applied by code
    # outside this class (TODO confirm).
    SQLITE_DB = os.getenv('SQLITE_DB', '')

    # Flask
    FLASK_HOST = '0.0.0.0'  # all IPv4 interfaces
    FLASK_PORT = 5001
    # Debug mode is opt-in: the Werkzeug debugger must not be exposed by
    # default, especially with a host of '0.0.0.0'. Set FLASK_DEBUG to
    # '1', 'true', 'yes' or 'on' (case-insensitive) to enable it; the accepted
    # '1' matches Flask's own FLASK_DEBUG convention.
    FLASK_DEBUG = os.getenv('FLASK_DEBUG', 'false').strip().lower() in (
        '1', 'true', 'yes', 'on',
    )

    # Directories. These are relative paths, so they resolve against the
    # current working directory at the moment they are used.
    MODEL_DIR = 'models'
    DATA_DIR = 'data'
    LOG_DIR = 'logs'
    UPLOAD_DIR = 'uploads'
    TEMPLATE_DIR = 'templates'

    # File uploads
    ALLOWED_EXTENSIONS = {'xlsx', 'xls', 'csv'}
    MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16 MB

    # API
    API_VERSION = 'v1'
    API_PREFIX = f'/api/{API_VERSION}'

    # Logging
    LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    LOG_LEVEL = 'INFO'
    LOG_FILE = os.path.join(LOG_DIR, 'app.log')
    LOG_MAX_SIZE = 10 * 1024 * 1024  # 10 MB per file before rotating
    LOG_BACKUP_COUNT = 5  # number of rotated files kept

    # PyTorch
    DEVICE = 'cpu'  # or 'cuda' to use a GPU
    BATCH_SIZE = 32
    LEARNING_RATE = 0.001
    NUM_EPOCHS = 100

    # Model training
    TRAIN_TEST_SPLIT = 0.2  # presumably the held-out test fraction; confirm with the training code
    RANDOM_SEED = 42
    EARLY_STOPPING_PATIENCE = 10
    MODEL_CHECKPOINT_DIR = os.path.join(MODEL_DIR, 'checkpoints')

    # Cache
    CACHE_TYPE = 'simple'
    CACHE_DEFAULT_TIMEOUT = 300  # presumably seconds; confirm against the cache backend

    # Security. The fallbacks are placeholders, not real secrets;
    # init_app() logs a warning while either one is still in effect.
    _SECRET_KEY_PLACEHOLDER = 'your-secret-key-here'
    _JWT_SECRET_KEY_PLACEHOLDER = 'your-jwt-secret-key-here'
    SECRET_KEY = os.getenv('SECRET_KEY', _SECRET_KEY_PLACEHOLDER)
    JWT_SECRET_KEY = os.getenv('JWT_SECRET_KEY', _JWT_SECRET_KEY_PLACEHOLDER)
    JWT_ACCESS_TOKEN_EXPIRES = 3600  # 1 hour

    # CORS: origins allowed to call /api/* (applied in init_app).
    CORS_ORIGINS = ['http://localhost:8080', 'http://127.0.0.1:8080']

    # Data validation
    MAX_EQUIPMENT_NAME_LENGTH = 100
    MAX_MANUFACTURER_NAME_LENGTH = 100

    @classmethod
    def init_app(cls, app):
        """Apply this configuration to a Flask application.

        Creates the working directories, attaches a rotating UTF-8 file log
        handler to ``app.logger``, stores the upload settings in
        ``app.config`` and enables CORS for ``/api/*``. Calling it again for
        the same logger does not attach a second handler for ``LOG_FILE``.

        Args:
            app: The Flask application to configure.

        Returns:
            The same ``app`` object, to allow chaining.
        """
        # Create the working directories (makedirs also creates MODEL_DIR
        # as the parent of MODEL_CHECKPOINT_DIR).
        for directory in (cls.MODEL_DIR, cls.DATA_DIR, cls.LOG_DIR,
                          cls.UPLOAD_DIR, cls.MODEL_CHECKPOINT_DIR):
            os.makedirs(directory, exist_ok=True)

        # Logging
        import logging
        from logging.handlers import RotatingFileHandler

        # FileHandler stores the absolute path in `baseFilename`. Comparing
        # against it detects a handler left by an earlier call (e.g. another
        # app created with the same name). A second handler would duplicate
        # every record and have two handlers rotating the same file.
        log_path = os.path.abspath(cls.LOG_FILE)
        has_file_handler = any(
            isinstance(handler, RotatingFileHandler)
            and handler.baseFilename == log_path
            for handler in app.logger.handlers
        )
        if not has_file_handler:
            file_handler = RotatingFileHandler(
                cls.LOG_FILE,
                maxBytes=cls.LOG_MAX_SIZE,
                backupCount=cls.LOG_BACKUP_COUNT,
                # Explicit encoding: the platform default (e.g. on Windows)
                # may be unable to encode some non-ASCII log messages.
                encoding='utf-8',
            )
            file_handler.setFormatter(logging.Formatter(cls.LOG_FORMAT))
            file_handler.setLevel(cls.LOG_LEVEL)
            app.logger.addHandler(file_handler)
        app.logger.setLevel(cls.LOG_LEVEL)

        # Warn (after logging is configured, so it reaches the log file)
        # when a secret still holds its placeholder fallback.
        for name, placeholder in (
            ('SECRET_KEY', cls._SECRET_KEY_PLACEHOLDER),
            ('JWT_SECRET_KEY', cls._JWT_SECRET_KEY_PLACEHOLDER),
        ):
            if getattr(cls, name) == placeholder:
                app.logger.warning(
                    '%s is using the built-in placeholder value; set the %s '
                    'environment variable before deploying.', name, name)

        # Uploads
        app.config['UPLOAD_FOLDER'] = cls.UPLOAD_DIR
        app.config['MAX_CONTENT_LENGTH'] = cls.MAX_CONTENT_LENGTH

        # CORS
        from flask_cors import CORS

        CORS(app, resources={
            r"/api/*": {"origins": cls.CORS_ORIGINS}
        })

        return app
|
||
|
||
# Shared module-level configuration object. Config defines no __init__ and
# keeps every setting as a class attribute, so this instance exposes exactly
# the same values as the class itself.
config: Config = Config()