# CostPrediction/config.py

# 103 lines
# 3.0 KiB
# Python

import os
class Config:
    """Application settings plus a one-shot initializer for a Flask app.

    Every setting is a plain class attribute. Values obtained through
    ``os.getenv`` are resolved once, when this class body executes (i.e. at
    import time), so later changes to the environment are not picked up.

    All directory settings are relative paths and are therefore resolved
    against the current working directory of the process.

    NOTE(review): the fallbacks for ``MYSQL_PASSWORD``, ``SECRET_KEY`` and
    ``JWT_SECRET_KEY`` are placeholders, and ``FLASK_DEBUG`` defaults to on
    while ``FLASK_HOST`` binds all interfaces. Override these through the
    environment for anything other than local development.
    """

    # Database settings (overridable through environment variables).
    MYSQL_HOST = os.getenv('MYSQL_HOST', 'localhost')
    MYSQL_USER = os.getenv('MYSQL_USER', 'root')
    MYSQL_PASSWORD = os.getenv('MYSQL_PASSWORD', '123456')
    MYSQL_DB = os.getenv('MYSQL_DB', 'equipment_cost_db')

    # Flask server settings.
    FLASK_HOST = '0.0.0.0'
    FLASK_PORT = 5001
    # Any value other than a case-insensitive "true" disables debug mode.
    FLASK_DEBUG = os.getenv('FLASK_DEBUG', 'True').lower() == 'true'

    # Directory layout.
    MODEL_DIR = 'models'
    DATA_DIR = 'data'
    LOG_DIR = 'logs'
    UPLOAD_DIR = 'uploads'
    TEMPLATE_DIR = 'templates'

    # File upload settings.
    ALLOWED_EXTENSIONS = {'xlsx', 'xls', 'csv'}
    MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16 MB

    # API settings.
    API_VERSION = 'v1'
    API_PREFIX = f'/api/{API_VERSION}'

    # Logging settings.
    LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    LOG_LEVEL = 'INFO'
    LOG_FILE = os.path.join(LOG_DIR, 'app.log')
    LOG_MAX_SIZE = 10 * 1024 * 1024  # 10 MB per log file before rotation
    LOG_BACKUP_COUNT = 5

    # PyTorch settings.
    DEVICE = 'cpu'  # or 'cuda' to use a GPU
    BATCH_SIZE = 32
    LEARNING_RATE = 0.001
    NUM_EPOCHS = 100

    # Model training settings.
    # NOTE(review): presumably the held-out fraction; confirm against the
    # training code that consumes it.
    TRAIN_TEST_SPLIT = 0.2
    RANDOM_SEED = 42
    EARLY_STOPPING_PATIENCE = 10
    MODEL_CHECKPOINT_DIR = os.path.join(MODEL_DIR, 'checkpoints')

    # Cache settings.
    CACHE_TYPE = 'simple'
    CACHE_DEFAULT_TIMEOUT = 300

    # Security settings (placeholder fallbacks; see the class docstring).
    SECRET_KEY = os.getenv('SECRET_KEY', 'your-secret-key-here')
    JWT_SECRET_KEY = os.getenv('JWT_SECRET_KEY', 'your-jwt-secret-key-here')
    JWT_ACCESS_TOKEN_EXPIRES = 3600  # 1 hour

    # Cross-origin (CORS) settings.
    CORS_ORIGINS = ['http://localhost:8080', 'http://127.0.0.1:8080']

    # Data validation settings.
    MAX_EQUIPMENT_NAME_LENGTH = 100
    MAX_MANUFACTURER_NAME_LENGTH = 100

    @classmethod
    def init_app(cls, app):
        """Create working directories and wire logging, uploads and CORS.

        The method may be called more than once for the same logger: the
        rotating file handler is attached only when the logger does not
        already have one writing to ``LOG_FILE``, so repeated
        initialisation does not duplicate log lines.

        Args:
            app: The Flask application to configure. Its ``logger`` and
                ``config`` attributes are modified in place.

        Returns:
            The same ``app`` object that was passed in.
        """
        # Imported here, as in the original, so that merely importing the
        # configuration module does not require Flask-CORS to be installed.
        import logging
        from logging.handlers import RotatingFileHandler

        # Create the required directories. LOG_DIR must exist before the
        # file handler below opens LOG_FILE.
        required_dirs = (
            cls.MODEL_DIR,
            cls.DATA_DIR,
            cls.LOG_DIR,
            cls.UPLOAD_DIR,
            cls.MODEL_CHECKPOINT_DIR,
        )
        for directory in required_dirs:
            os.makedirs(directory, exist_ok=True)

        # Configure logging. FileHandler stores its target as an absolute
        # path in ``baseFilename``, so compare against the absolute path.
        log_path = os.path.abspath(cls.LOG_FILE)
        already_attached = any(
            isinstance(handler, RotatingFileHandler)
            and handler.baseFilename == log_path
            for handler in app.logger.handlers
        )
        if not already_attached:
            file_handler = RotatingFileHandler(
                cls.LOG_FILE,
                maxBytes=cls.LOG_MAX_SIZE,
                backupCount=cls.LOG_BACKUP_COUNT,
                # Explicit encoding: the locale default cannot represent
                # non-ASCII log messages on every platform.
                encoding='utf-8',
            )
            file_handler.setFormatter(logging.Formatter(cls.LOG_FORMAT))
            file_handler.setLevel(cls.LOG_LEVEL)
            app.logger.addHandler(file_handler)
        app.logger.setLevel(cls.LOG_LEVEL)

        # Configure the upload folder and request size limit.
        app.config['UPLOAD_FOLDER'] = cls.UPLOAD_DIR
        app.config['MAX_CONTENT_LENGTH'] = cls.MAX_CONTENT_LENGTH

        # Configure CORS for the API routes only.
        from flask_cors import CORS
        CORS(app, resources={
            r"/api/*": {"origins": cls.CORS_ORIGINS}
        })
        return app
# Module-level configuration instance. Config keeps its settings as class
# attributes, so attribute reads on this object resolve to the class values.
config = Config()