# CostPrediction/config.py
#
# Captured from a web code viewer (100 lines, 2.9 KiB, Python). The viewer
# warned that this file contains ambiguous Unicode characters - characters
# that might be confused with other characters; if that is intentional, the
# warning can be ignored.

import os
class Config:
    """Application-wide configuration settings.

    All settings are plain class attributes, so they can be read either from
    the class (``Config.DATA_DIR``) or from an instance. The few settings that
    consult environment variables read them once, when this class body is
    executed at module import time.
    """

    # Database (SQLite). An empty value means the default path
    # data/equipment_cost.db is used by the code that opens the database.
    SQLITE_DB = os.getenv('SQLITE_DB', '')

    # Flask server.
    FLASK_HOST = '0.0.0.0'  # all interfaces, when used as the bind host
    FLASK_PORT = 5001
    # True when FLASK_DEBUG is unset or equals "true" (case-insensitive);
    # any other value, including "1" or "yes", gives False.
    # NOTE(review): debug defaults to ON; combined with FLASK_HOST='0.0.0.0'
    # this is dangerous if the flag reaches app.run(debug=...) in production.
    FLASK_DEBUG = os.getenv('FLASK_DEBUG', 'True').lower() == 'true'

    # Working directories. These are relative paths, so they resolve against
    # the process's current working directory. init_app() creates all of them
    # (plus MODEL_CHECKPOINT_DIR) except TEMPLATE_DIR.
    MODEL_DIR = 'models'
    DATA_DIR = 'data'
    LOG_DIR = 'logs'
    UPLOAD_DIR = 'uploads'
    TEMPLATE_DIR = 'templates'

    # File uploads.
    ALLOWED_EXTENSIONS = {'xlsx', 'xls', 'csv'}  # extensions without the dot
    MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16 MB; copied into app.config

    # API routing.
    API_VERSION = 'v1'
    API_PREFIX = f'/api/{API_VERSION}'  # '/api/v1'

    # Logging.
    LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    LOG_LEVEL = 'INFO'
    LOG_FILE = os.path.join(LOG_DIR, 'app.log')
    LOG_MAX_SIZE = 10 * 1024 * 1024  # rotate once the file reaches 10 MB
    LOG_BACKUP_COUNT = 5  # number of rotated log files kept

    # PyTorch.
    DEVICE = 'cpu'  # or 'cuda' to use a GPU
    BATCH_SIZE = 32
    LEARNING_RATE = 0.001
    NUM_EPOCHS = 100

    # Model training.
    TRAIN_TEST_SPLIT = 0.2
    RANDOM_SEED = 42
    EARLY_STOPPING_PATIENCE = 10
    MODEL_CHECKPOINT_DIR = os.path.join(MODEL_DIR, 'checkpoints')

    # Caching.
    CACHE_TYPE = 'simple'
    CACHE_DEFAULT_TIMEOUT = 300

    # Security.
    # NOTE(review): both keys fall back to well-known placeholder strings when
    # the environment variable is unset; always set SECRET_KEY and
    # JWT_SECRET_KEY outside local development.
    SECRET_KEY = os.getenv('SECRET_KEY', 'your-secret-key-here')
    JWT_SECRET_KEY = os.getenv('JWT_SECRET_KEY', 'your-jwt-secret-key-here')
    JWT_ACCESS_TOKEN_EXPIRES = 3600  # 1 hour, in seconds

    # CORS: origins allowed to call /api/* routes (applied in init_app).
    CORS_ORIGINS = ['http://localhost:8080', 'http://127.0.0.1:8080']

    # Input validation limits.
    MAX_EQUIPMENT_NAME_LENGTH = 100
    MAX_MANUFACTURER_NAME_LENGTH = 100

    @classmethod
    def init_app(cls, app):
        """Apply this configuration to a Flask application.

        Creates the working directories, attaches a size-rotating UTF-8 log
        file handler to ``app.logger``, copies the upload settings into
        ``app.config`` and enables CORS for ``/api/*`` routes.

        Calling this more than once does not attach a second file handler for
        the same ``LOG_FILE``. Flask's ``app.logger`` is looked up by the app's
        name, so apps with the same name share one logger. Without that guard,
        every call would add another handler and duplicate each log line.

        Args:
            app: the Flask application. This method itself touches only
                ``app.logger`` and ``app.config``; it is also passed to
                ``flask_cors.CORS``.

        Returns:
            The same ``app`` object, to allow chaining.
        """
        # Create the working directories (relative to the current directory).
        for directory in (cls.MODEL_DIR, cls.DATA_DIR, cls.LOG_DIR,
                          cls.UPLOAD_DIR, cls.MODEL_CHECKPOINT_DIR):
            os.makedirs(directory, exist_ok=True)

        # Logging: size-based rotation of LOG_FILE.
        import logging
        from logging.handlers import RotatingFileHandler

        # FileHandler stores the absolute path in ``baseFilename``; compare on
        # that so a handler for this same file is never attached twice.
        log_path = os.path.abspath(cls.LOG_FILE)
        already_attached = any(
            isinstance(handler, RotatingFileHandler)
            and handler.baseFilename == log_path
            for handler in app.logger.handlers
        )
        if not already_attached:
            file_handler = RotatingFileHandler(
                cls.LOG_FILE,
                maxBytes=cls.LOG_MAX_SIZE,
                backupCount=cls.LOG_BACKUP_COUNT,
                # Explicit UTF-8 so non-ASCII messages are written the same
                # way on every platform instead of in the locale encoding.
                encoding='utf-8',
            )
            file_handler.setFormatter(logging.Formatter(cls.LOG_FORMAT))
            file_handler.setLevel(cls.LOG_LEVEL)
            app.logger.addHandler(file_handler)
        app.logger.setLevel(cls.LOG_LEVEL)

        # Upload settings.
        app.config['UPLOAD_FOLDER'] = cls.UPLOAD_DIR
        app.config['MAX_CONTENT_LENGTH'] = cls.MAX_CONTENT_LENGTH

        # Cross-origin requests: only /api/* routes, only CORS_ORIGINS.
        from flask_cors import CORS
        CORS(app, resources={
            r"/api/*": {"origins": cls.CORS_ORIGINS}
        })
        return app
# Shared module-level configuration object (an instance of Config).
config: Config = Config()