# MLPlatform/function/model_manager.py
#
# 134 lines
# 4.8 KiB
# Python

import mlflow
from mlflow.tracking import MlflowClient
import pandas as pd
from typing import Dict, List, Optional
import logging
from pathlib import Path
import datetime
import yaml
import json
class ModelManager:
    """Model management helper backed by an MLflow tracking server.

    Exposes finished MLflow runs as "model" records, returned in a dict-based
    response format: ``{'status': 'success', ...}`` on success or
    ``{'status': 'error', 'message': ...}`` on failure.
    """

    # Tracking server used when ``config`` does not provide ``mlflow_uri``.
    DEFAULT_MLFLOW_URI = 'http://10.0.0.202:5000'
    # Params that describe the algorithm itself rather than its configuration;
    # they are reported under ``algorithm_info`` instead of ``parameters``.
    _ALGORITHM_INFO_KEYS = ('principle', 'advantages', 'disadvantages')

    def __init__(self, config: Optional[Dict] = None):
        """Initialise the model manager.

        Args:
            config: Optional settings. Recognised keys:
                ``mlflow_uri`` -- tracking server URI
                (default ``DEFAULT_MLFLOW_URI``);
                ``log_dir`` -- directory for log files (default ``.log``).
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)
        self._setup_logging()
        # Initialise MLflow. set_tracking_uri is process-global and is still
        # required because get_finished_models uses the fluent ``mlflow.*`` API;
        # the client gets the URI explicitly so it does not depend on that global.
        self.mlflow_uri = self.config.get('mlflow_uri', self.DEFAULT_MLFLOW_URI)
        mlflow.set_tracking_uri(self.mlflow_uri)
        self.client = MlflowClient(tracking_uri=self.mlflow_uri)

    def _setup_logging(self):
        """Attach a timestamped file handler to this module's logger.

        ``logging.getLogger(__name__)`` returns the same logger object for every
        instance, so the handler is only added when the logger has no
        FileHandler yet. Otherwise each new ModelManager would add another
        handler, duplicating every log line and leaking a file descriptor.
        """
        self.logger.setLevel(logging.INFO)
        if any(isinstance(h, logging.FileHandler) for h in self.logger.handlers):
            return
        log_dir = Path(self.config.get('log_dir', '.log'))
        log_dir.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(
            log_dir / f'model_manager_{datetime.datetime.now():%Y%m%d_%H%M%S}.log',
            encoding='utf-8',  # experiment names/messages may be non-ASCII
        )
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        self.logger.addHandler(file_handler)

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True for None/NaN cells, i.e. keys a given run never logged."""
        try:
            return bool(pd.isna(value))
        except (TypeError, ValueError):
            # Non-scalar values make pd.isna return an array; treat as present.
            return False

    @classmethod
    def _collect_prefixed(cls, run: Dict, prefix: str) -> Dict:
        """Return ``{name: value}`` for every ``run`` key of the form ``prefix + name``.

        ``mlflow.search_runs`` expands params/metrics into ``params.*`` /
        ``metrics.*`` columns shared by all returned runs, filling the cells of
        runs that did not log a key with None/NaN; those cells are skipped so
        they do not show up as fake parameters or metrics.
        """
        return {
            key[len(prefix):]: value  # strip only the leading prefix
            for key, value in run.items()
            if key.startswith(prefix) and not cls._is_missing(value)
        }

    @classmethod
    def _format_run(cls, run: Dict) -> Dict:
        """Convert one ``search_runs`` record into the model-info dict."""
        params = cls._collect_prefixed(run, 'params.')
        metrics = cls._collect_prefixed(run, 'metrics.')
        start = pd.to_datetime(run['start_time'])
        return {
            # NOTE(review): second-resolution IDs collide when two runs start in
            # the same second; 'mlflow_run_id' is the unique identifier.
            'model_id': f"model_{start.strftime('%Y%m%d_%H%M%S')}",
            # .get: a run may not have logged these params; report None instead
            # of failing the whole listing with a KeyError.
            'algorithm': params.get('algorithm'),
            'task_type': params.get('task_type'),
            'dataset': params.get('dataset'),
            'training_start_time': run['start_time'],
            'training_end_time': run['end_time'],
            'metrics': metrics,
            'parameters': {
                k: v for k, v in params.items()
                if k not in cls._ALGORITHM_INFO_KEYS
            },
            'algorithm_info': {
                key: params.get(key, '') for key in cls._ALGORITHM_INFO_KEYS
            },
            'mlflow_run_id': run['run_id'],
            'model_path': f"models/{run['run_id']}",
        }

    def get_finished_models(
        self,
        page: int = 1,
        page_size: int = 10,
        experiment_name: Optional[str] = None
    ) -> Dict:
        """
        Get the list of models (MLflow runs) whose status is FINISHED.

        Args:
            page: 1-based page number.
            page_size: Number of models per page (must be >= 1).
            experiment_name: If given, only runs of this experiment are listed.

        Returns:
            On success: ``{'status': 'success', 'models': [...],
            'total_count': int, 'page': page, 'page_size': page_size}``.
            On failure (invalid pagination, unknown experiment, or any error
            while querying MLflow): ``{'status': 'error', 'message': str}``.
            This method does not raise.
        """
        try:
            # Reject non-positive values: they would otherwise produce negative
            # slice indices and silently return the wrong window of runs.
            if page < 1 or page_size < 1:
                return {
                    'status': 'error',
                    'message': (
                        f'Invalid pagination: page={page}, page_size={page_size} '
                        '(both must be >= 1)'
                    )
                }
            # Resolve the experiments to search.
            if experiment_name:
                experiment = mlflow.get_experiment_by_name(experiment_name)
                if experiment is None:
                    return {
                        'status': 'error',
                        'message': f'Experiment {experiment_name} not found'
                    }
                experiments = [experiment]
            else:
                experiments = mlflow.search_experiments()
            # Collect all finished runs, one query per experiment.
            all_runs = []
            for exp in experiments:
                runs = mlflow.search_runs(
                    experiment_ids=[exp.experiment_id],
                    filter_string="status = 'FINISHED'"
                )
                all_runs.extend(runs.to_dict('records'))
            # Paginate in memory.
            total_count = len(all_runs)
            start_idx = (page - 1) * page_size
            page_runs = all_runs[start_idx:start_idx + page_size]
            models = [self._format_run(run) for run in page_runs]
            return {
                'status': 'success',
                'models': models,
                'total_count': total_count,
                'page': page,
                'page_size': page_size
            }
        except Exception as e:
            error_msg = f"Error getting finished models: {str(e)}"
            # logger.exception keeps the traceback in the log file.
            self.logger.exception(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }