import os


class Config:
    """Application configuration.

    Values are evaluated once, when the class body runs. A few of them can be
    overridden through environment variables (``SQLITE_DB``, ``FLASK_DEBUG``,
    ``SECRET_KEY``, ``JWT_SECRET_KEY``). Call :meth:`init_app` on a Flask
    application to create the working directories and to wire up logging,
    upload limits and CORS.
    """

    # Database (SQLite). An empty value means "use the default path
    # data/equipment_cost.db"; that fallback is resolved outside this class.
    SQLITE_DB = os.getenv('SQLITE_DB', '')

    # Flask
    FLASK_HOST = '0.0.0.0'
    FLASK_PORT = 5001
    # NOTE(review): debug defaults to ON. Together with FLASK_HOST = '0.0.0.0'
    # this would expose the Werkzeug interactive debugger to the network if the
    # app is started with these values; set FLASK_DEBUG=false outside dev.
    FLASK_DEBUG = os.getenv('FLASK_DEBUG', 'True').lower() == 'true'

    # Working directories (relative by default, i.e. resolved against the
    # process's current working directory).
    MODEL_DIR = 'models'
    DATA_DIR = 'data'
    LOG_DIR = 'logs'
    UPLOAD_DIR = 'uploads'
    TEMPLATE_DIR = 'templates'

    # File uploads
    ALLOWED_EXTENSIONS = {'xlsx', 'xls', 'csv'}
    MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16 MB

    # API
    API_VERSION = 'v1'
    API_PREFIX = f'/api/{API_VERSION}'

    # Logging
    LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    LOG_LEVEL = 'INFO'
    # Computed from LOG_DIR here, so a subclass overriding LOG_DIR must also
    # override LOG_FILE to move the log file.
    LOG_FILE = os.path.join(LOG_DIR, 'app.log')
    LOG_MAX_SIZE = 10 * 1024 * 1024  # 10 MB per file before rotating
    LOG_BACKUP_COUNT = 5

    # PyTorch
    DEVICE = 'cpu'  # or 'cuda' to use the GPU
    BATCH_SIZE = 32
    LEARNING_RATE = 0.001
    NUM_EPOCHS = 100

    # Model training
    TRAIN_TEST_SPLIT = 0.2  # presumably the held-out fraction; confirm with the trainer
    RANDOM_SEED = 42
    EARLY_STOPPING_PATIENCE = 10
    MODEL_CHECKPOINT_DIR = os.path.join(MODEL_DIR, 'checkpoints')

    # Cache
    CACHE_TYPE = 'simple'
    CACHE_DEFAULT_TIMEOUT = 300

    # Security. The fallbacks are placeholders only; init_app logs a warning
    # while either one is still in use.
    _DEFAULT_SECRET_KEY = 'your-secret-key-here'
    _DEFAULT_JWT_SECRET_KEY = 'your-jwt-secret-key-here'
    SECRET_KEY = os.getenv('SECRET_KEY', _DEFAULT_SECRET_KEY)
    JWT_SECRET_KEY = os.getenv('JWT_SECRET_KEY', _DEFAULT_JWT_SECRET_KEY)
    JWT_ACCESS_TOKEN_EXPIRES = 3600  # 1 hour, in seconds

    # Cross-origin requests (CORS)
    CORS_ORIGINS = ['http://localhost:8080', 'http://127.0.0.1:8080']

    # Input validation limits
    MAX_EQUIPMENT_NAME_LENGTH = 100
    MAX_MANUFACTURER_NAME_LENGTH = 100

    @classmethod
    def init_app(cls, app):
        """Apply this configuration to a Flask application.

        Creates the working directories and attaches a rotating file handler
        to ``app.logger``. Warns if placeholder secrets are still in use,
        stores the upload settings in ``app.config`` and enables CORS for
        ``/api/*``. Calling it again for an app whose logger already has a
        handler for ``LOG_FILE`` does not add a second one.

        Args:
            app: the Flask application to configure.

        Returns:
            The same ``app`` object, to allow chaining.
        """
        cls._ensure_directories()
        cls._configure_logging(app)
        cls._warn_insecure_defaults(app)
        cls._configure_uploads(app)
        cls._configure_cors(app)
        return app

    @classmethod
    def _ensure_directories(cls):
        """Create every working directory this configuration refers to."""
        directories = [cls.MODEL_DIR, cls.DATA_DIR, cls.LOG_DIR,
                       cls.UPLOAD_DIR, cls.MODEL_CHECKPOINT_DIR]
        # LOG_FILE need not live in LOG_DIR (e.g. a subclass overrode only
        # LOG_DIR), so make sure the log file's own parent exists too.
        log_parent = os.path.dirname(cls.LOG_FILE)
        if log_parent:
            directories.append(log_parent)
        for directory in directories:
            os.makedirs(directory, exist_ok=True)

    @classmethod
    def _configure_logging(cls, app):
        """Attach one size-rotated, UTF-8 file handler to ``app.logger``."""
        import logging
        from logging.handlers import RotatingFileHandler

        log_path = os.path.abspath(cls.LOG_FILE)
        # FileHandler stores the absolute path in ``baseFilename``; use it to
        # detect a handler from an earlier call so lines are not duplicated
        # and the same file is not opened twice.
        already_attached = any(
            isinstance(handler, RotatingFileHandler)
            and handler.baseFilename == log_path
            for handler in app.logger.handlers
        )
        if not already_attached:
            file_handler = RotatingFileHandler(
                cls.LOG_FILE,
                maxBytes=cls.LOG_MAX_SIZE,
                backupCount=cls.LOG_BACKUP_COUNT,
                # Explicit encoding: messages may contain non-ASCII text and
                # the platform default encoding is not guaranteed to be UTF-8.
                encoding='utf-8',
            )
            file_handler.setFormatter(logging.Formatter(cls.LOG_FORMAT))
            file_handler.setLevel(cls.LOG_LEVEL)
            app.logger.addHandler(file_handler)
        app.logger.setLevel(cls.LOG_LEVEL)

    @classmethod
    def _warn_insecure_defaults(cls, app):
        """Log a warning for each secret still equal to its placeholder."""
        checks = (
            ('SECRET_KEY', cls.SECRET_KEY, cls._DEFAULT_SECRET_KEY),
            ('JWT_SECRET_KEY', cls.JWT_SECRET_KEY, cls._DEFAULT_JWT_SECRET_KEY),
        )
        for name, value, placeholder in checks:
            if value == placeholder:
                app.logger.warning(
                    '%s is using the built-in placeholder value; '
                    'set the %s environment variable in production.',
                    name, name,
                )

    @classmethod
    def _configure_uploads(cls, app):
        """Store the upload directory and request-size limit in ``app.config``."""
        app.config['UPLOAD_FOLDER'] = cls.UPLOAD_DIR
        app.config['MAX_CONTENT_LENGTH'] = cls.MAX_CONTENT_LENGTH

    @classmethod
    def _configure_cors(cls, app):
        """Allow cross-origin requests to ``/api/*`` from ``CORS_ORIGINS``."""
        # Imported lazily so that importing this module does not require
        # flask_cors to be installed.
        from flask_cors import CORS
        CORS(app, resources={
            r"/api/*": {"origins": cls.CORS_ORIGINS}
        })


# Shared configuration instance.
config = Config()