103 lines
3.0 KiB
Python
103 lines
3.0 KiB
Python
import os
|
|
|
|
class Config:
    """Application configuration.

    Every setting is a class attribute evaluated once, when this class body
    runs at import time. The database and security settings (and
    FLASK_DEBUG) can be overridden through environment variables of the same
    name; everything else is fixed here. Call :meth:`init_app` to apply the
    runtime-relevant settings to a Flask application.
    """

    # --- Database ---------------------------------------------------------
    MYSQL_HOST = os.getenv('MYSQL_HOST', 'localhost')
    MYSQL_USER = os.getenv('MYSQL_USER', 'root')
    # NOTE(review): hard-coded fallback password; MYSQL_PASSWORD should be
    # set in the environment anywhere other than local development.
    MYSQL_PASSWORD = os.getenv('MYSQL_PASSWORD', '123456')
    MYSQL_DB = os.getenv('MYSQL_DB', 'equipment_cost_db')

    # --- Flask ------------------------------------------------------------
    FLASK_HOST = '0.0.0.0'  # listen on all network interfaces
    FLASK_PORT = 5001
    # Debug defaults to ON when FLASK_DEBUG is unset. Besides 'true', the
    # tokens '1', 'yes' and 'on' are accepted so the Flask-documented
    # ``FLASK_DEBUG=1`` enables debug mode instead of silently disabling it.
    FLASK_DEBUG = (
        os.getenv('FLASK_DEBUG', 'True').strip().lower()
        in ('1', 'true', 'yes', 'on')
    )

    # --- Directories ------------------------------------------------------
    # Relative paths: they resolve against the process's current working
    # directory, not against this file's location.
    MODEL_DIR = 'models'
    DATA_DIR = 'data'
    LOG_DIR = 'logs'
    UPLOAD_DIR = 'uploads'
    TEMPLATE_DIR = 'templates'

    # --- File uploads -----------------------------------------------------
    ALLOWED_EXTENSIONS = {'xlsx', 'xls', 'csv'}
    MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16 MB

    # --- API --------------------------------------------------------------
    API_VERSION = 'v1'
    API_PREFIX = f'/api/{API_VERSION}'

    # --- Logging ----------------------------------------------------------
    LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    LOG_LEVEL = 'INFO'
    # Computed from LOG_DIR when this class body runs; a subclass that
    # overrides LOG_DIR must override LOG_FILE too for the log to move.
    LOG_FILE = os.path.join(LOG_DIR, 'app.log')
    LOG_MAX_SIZE = 10 * 1024 * 1024  # 10 MB per file before rotating
    LOG_BACKUP_COUNT = 5  # number of rotated files kept

    # --- PyTorch ----------------------------------------------------------
    DEVICE = 'cpu'  # or 'cuda' to use a GPU
    BATCH_SIZE = 32
    LEARNING_RATE = 0.001
    NUM_EPOCHS = 100

    # --- Model training ---------------------------------------------------
    TRAIN_TEST_SPLIT = 0.2  # presumably the fraction held out for testing
    RANDOM_SEED = 42
    EARLY_STOPPING_PATIENCE = 10
    MODEL_CHECKPOINT_DIR = os.path.join(MODEL_DIR, 'checkpoints')

    # --- Cache ------------------------------------------------------------
    CACHE_TYPE = 'simple'
    CACHE_DEFAULT_TIMEOUT = 300  # presumably seconds -- confirm with consumer

    # --- Security ---------------------------------------------------------
    # NOTE(review): the fallbacks below are placeholders; SECRET_KEY and
    # JWT_SECRET_KEY must come from the environment outside development.
    SECRET_KEY = os.getenv('SECRET_KEY', 'your-secret-key-here')
    JWT_SECRET_KEY = os.getenv('JWT_SECRET_KEY', 'your-jwt-secret-key-here')
    JWT_ACCESS_TOKEN_EXPIRES = 3600  # 1 hour

    # --- CORS -------------------------------------------------------------
    CORS_ORIGINS = ['http://localhost:8080', 'http://127.0.0.1:8080']

    # --- Data validation --------------------------------------------------
    MAX_EQUIPMENT_NAME_LENGTH = 100
    MAX_MANUFACTURER_NAME_LENGTH = 100

    @classmethod
    def init_app(cls, app):
        """Apply this configuration to a Flask application.

        Steps, in order:

        1. Create the working directories.
        2. Attach a rotating UTF-8 file handler to ``app.logger`` and set the
           logger level.
        3. Store the upload options in ``app.config``.
        4. Enable CORS for ``/api/*`` routes from ``CORS_ORIGINS``.

        The file handler is attached only if ``app.logger`` does not already
        carry a rotating handler for the same log file. Repeated calls
        therefore do not duplicate every log line.

        Args:
            app: The Flask application to configure.

        Returns:
            The same ``app`` object, to allow chaining.
        """
        # Create the working directories. The log file's own directory is
        # included because LOG_FILE is fixed at class-definition time and
        # may not live under LOG_DIR in a subclass.
        directories = [cls.MODEL_DIR, cls.DATA_DIR, cls.LOG_DIR,
                       cls.UPLOAD_DIR, cls.MODEL_CHECKPOINT_DIR,
                       os.path.dirname(cls.LOG_FILE)]
        for directory in directories:
            if directory:  # dirname() of a bare filename is ''
                os.makedirs(directory, exist_ok=True)

        # Logging.
        import logging
        from logging.handlers import RotatingFileHandler

        # app.logger is a named logger that can outlive one init_app call
        # (e.g. several apps with the same name in tests), so skip adding a
        # second handler for the same file. FileHandler stores the absolute
        # path in ``baseFilename``.
        log_path = os.path.abspath(cls.LOG_FILE)
        already_attached = any(
            isinstance(handler, RotatingFileHandler)
            and handler.baseFilename == log_path
            for handler in app.logger.handlers
        )
        if not already_attached:
            file_handler = RotatingFileHandler(
                cls.LOG_FILE,
                maxBytes=cls.LOG_MAX_SIZE,
                backupCount=cls.LOG_BACKUP_COUNT,
                # Explicit encoding so non-ASCII messages do not depend on
                # the platform's locale encoding.
                encoding='utf-8',
            )
            file_handler.setFormatter(logging.Formatter(cls.LOG_FORMAT))
            file_handler.setLevel(cls.LOG_LEVEL)
            app.logger.addHandler(file_handler)

        app.logger.setLevel(cls.LOG_LEVEL)

        # Upload settings.
        app.config['UPLOAD_FOLDER'] = cls.UPLOAD_DIR
        app.config['MAX_CONTENT_LENGTH'] = cls.MAX_CONTENT_LENGTH

        # Cross-origin access to the API, restricted to CORS_ORIGINS.
        from flask_cors import CORS
        CORS(app, resources={
            r"/api/*": {"origins": cls.CORS_ORIGINS}
        })

        return app
|
|
|
|
# Module-level configuration instance. All settings live on the class, so
# attribute reads through this instance return the ``Config`` class values.
config: Config = Config()