# MLPlatform/function/method_reader_model.py

import yaml
from typing import Dict, List
import os
import logging
from pathlib import Path
class MethodReader:
    """Read the model catalogue and per-model parameter definitions from YAML.

    Two YAML files are loaded eagerly at construction time:

    * the *method* config (default ``model/model.yaml``): each category key
      maps model names to descriptive info (``principle``, ``advantages``,
      ``disadvantages``, ``applicable_scenarios``);
    * the *parameter* config (default ``model/parameter.yaml``): each
      category key maps model names to a ``description`` and a
      ``parameters`` list.

    Relative paths are resolved against the current working directory, so
    pass explicit paths when the process is not started from the project root.
    """

    # Recognised categories, in the order get_models() reports them, paired
    # with the human-readable description returned to callers.
    _CATEGORIES = (
        ("classification_algorithms", "分类方法"),
        ("regression_algorithms", "回归方法"),
        ("clustering_algorithms", "聚类方法"),
    )

    # Class-level defaults; __init__ overrides them per instance.
    model_config_path = Path("model/model.yaml")
    parameter_config_path = Path("model/parameter.yaml")

    def __init__(self, model_config_path: str = "model/model.yaml",
                 parameter_config_path: str = "model/parameter.yaml"):
        """Load both config files.

        Args:
            model_config_path: path (str or path-like) of the method config.
            parameter_config_path: path (str or path-like) of the parameter config.

        Raises:
            FileNotFoundError: if either file is missing.
            ValueError: if either file's top level is not a mapping.
            Any YAML/IO error is logged and re-raised.
        """
        self.logger = logging.getLogger(__name__)
        self.model_config_path = Path(model_config_path)
        self.parameter_config_path = Path(parameter_config_path)
        self.method_config = self._load_model_config()
        self.parameter_config = self._load_parameter_config()

    def _load_yaml(self, config_path: Path, label: str) -> Dict:
        """Load one YAML mapping from *config_path*.

        Args:
            config_path: file to read.
            label: short name ("method" / "parameter") used in log and error text.

        Returns:
            The parsed mapping; an empty file yields ``{}``.

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError: if the top level of the file is not a mapping.
        """
        try:
            if not config_path.exists():
                raise FileNotFoundError(
                    f"{label.capitalize()} config file not found at {config_path}"
                )
            with open(config_path, "r", encoding="utf-8") as f:
                # safe_load never constructs arbitrary Python objects.
                config = yaml.safe_load(f)
            if config is None:
                # An empty YAML document parses as None; normalise so that
                # later .get() lookups don't fail with AttributeError.
                config = {}
            elif not isinstance(config, dict):
                raise ValueError(
                    f"{label.capitalize()} config at {config_path} must be a mapping, "
                    f"got {type(config).__name__}"
                )
            self.logger.info("Successfully loaded %s config", label)
            return config
        except Exception as e:
            self.logger.error("Error loading %s config: %s", label, e)
            raise

    def _load_model_config(self) -> Dict:
        """Load the method (model description) config file."""
        return self._load_yaml(Path(self.model_config_path), "method")

    def _load_parameter_config(self) -> Dict:
        """Load the parameter config file."""
        return self._load_yaml(Path(self.parameter_config_path), "parameter")

    def _find_entry(self, config: Dict, method_name: str):
        """Return *method_name*'s entry from the first category that lists it.

        Returns ``None`` when no category contains the name. A name listed with
        no body (``svm:`` in YAML parses as None) is returned as ``{}`` so it
        is treated as present-but-empty rather than missing.
        """
        for category, _ in self._CATEGORIES:
            # A category key with no entries parses as None.
            entries = config.get(category) or {}
            if method_name in entries:
                return entries[method_name] or {}
        return None

    def get_models(self) -> Dict:
        """List the available models grouped by category.

        Returns:
            ``{"status": "success", "models": [...]}``, where each item is
            ``{"name": category_key, "description": str, "method": [names]}``.
            Categories with no models are omitted. On an unexpected error the
            result is ``{"status": "error", "error": message}``.
        """
        try:
            models = []
            for category, description in self._CATEGORIES:
                names = list(self.method_config.get(category) or {})
                if names:
                    models.append({
                        "name": category,
                        "description": description,
                        "method": names,
                    })
            return {
                "status": "success",
                "models": models,
            }
        except Exception as e:
            self.logger.error("Error getting models: %s", e)
            return {
                "status": "error",
                "error": str(e),
            }

    def get_model_details(self, method_name: str) -> Dict:
        """Return the combined description and parameters of one model.

        The descriptive fields come from the method config. ``description`` and
        ``parameters`` come from the parameter config. Missing fields default
        to ``''`` or ``[]``.

        Returns:
            ``{"status": "success", "method": {...}}``, or
            ``{"status": "error", "error": message}`` when the model is absent
            from either config or another error occurs.
        """
        try:
            method_info = self._find_entry(self.method_config, method_name)
            if method_info is None:
                raise ValueError(f"Method {method_name} not found in method config")

            parameter_info = self._find_entry(self.parameter_config, method_name)
            if parameter_info is None:
                raise ValueError(f"Method {method_name} not found in parameter config")

            return {
                "status": "success",
                "method": {
                    "name": method_name,
                    "description": parameter_info.get('description', ''),
                    "principle": method_info.get('principle', ''),
                    "advantages": method_info.get('advantages', []),
                    "disadvantages": method_info.get('disadvantages', []),
                    "applicable_scenarios": method_info.get('applicable_scenarios', []),
                    "parameters": parameter_info.get('parameters', []),
                },
            }
        except Exception as e:
            self.logger.error("Error getting method details: %s", e)
            return {
                "status": "error",
                "error": str(e),
            }