# MLPlatform/function/model_manager.py
#
# 216 lines
# 7.8 KiB
# Python

import datetime
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional

import mlflow
import pandas as pd
import yaml
from mlflow.entities import ViewType
from mlflow.tracking import MlflowClient
class ModelManager:
    """Query MLflow for experiments and finished training runs.

    The query methods (``get_finished_models`` / ``get_experiments``) do not
    raise: they return a dict whose ``status`` key is ``'success'`` (plus the
    requested data) or ``'error'`` (plus a ``message``).
    """

    # Timezone used to render timestamps when the config does not set one.
    DEFAULT_TIMEZONE = 'Asia/Shanghai'
    # Display format for every rendered timestamp.
    TIME_FORMAT = '%Y-%m-%d %H:%M:%S'
    # Run params that describe the algorithm rather than configure it; they are
    # reported under 'algorithm_info' and left out of 'parameters'.
    ALGORITHM_INFO_KEYS = ('principle', 'advantages', 'disadvantages')

    def __init__(self, config: Optional[Dict] = None):
        """Initialise the model manager.

        Args:
            config: Optional settings. Recognised keys:
                ``mlflow_uri`` - MLflow tracking URI
                (default ``'http://10.0.0.202:5000'``);
                ``timezone`` - timezone for displayed timestamps
                (default ``'Asia/Shanghai'``).
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)
        self._setup_logging()
        self.timezone = self.config.get('timezone', self.DEFAULT_TIMEZONE)
        # set_tracking_uri is process-global: it also drives the module-level
        # mlflow.* calls used by the query methods below.
        self.mlflow_uri = self.config.get('mlflow_uri', 'http://10.0.0.202:5000')
        mlflow.set_tracking_uri(self.mlflow_uri)
        self.client = MlflowClient(tracking_uri=self.mlflow_uri)

    def _setup_logging(self):
        """Attach a timestamped file handler under ``.log/`` to the module logger.

        The logger is shared by all instances (it is looked up by module name),
        so the handler is added only once. Otherwise every new ModelManager
        would open another log file and every message would be written once
        per instance created so far.
        """
        self.logger.setLevel(logging.INFO)
        if any(isinstance(h, logging.FileHandler) for h in self.logger.handlers):
            return
        log_dir = Path('.log')
        log_dir.mkdir(exist_ok=True)
        file_handler = logging.FileHandler(
            log_dir / f'model_manager_{datetime.datetime.now():%Y%m%d_%H%M%S}.log',
            encoding='utf-8',
        )
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        self.logger.addHandler(file_handler)

    @staticmethod
    def _paging_error(page: int, page_size: int) -> Optional[Dict]:
        """Return an error response when ``page`` or ``page_size`` is below 1, else None."""
        if page < 1 or page_size < 1:
            return {
                'status': 'error',
                'message': f'Invalid pagination: page={page}, page_size={page_size}; both must be >= 1',
            }
        return None

    @staticmethod
    def _paginate(items: List, page: int, page_size: int) -> List:
        """Return the 1-based ``page`` of ``items``, holding at most ``page_size`` entries."""
        start_idx = (page - 1) * page_size
        return items[start_idx:start_idx + page_size]

    @staticmethod
    def _clean_value(value):
        """Map a missing scalar (None / NaN / NaT) to None; return anything else unchanged.

        Runs that did not log a given param/metric get a missing value in the
        ``search_runs`` DataFrame, and NaN is not valid JSON.
        """
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        return value

    def _to_local_time_str(self, value, unit: Optional[str] = None) -> Optional[str]:
        """Render a timestamp in ``self.timezone`` using ``TIME_FORMAT``.

        Args:
            value: Anything ``pd.to_datetime`` accepts. Naive values are
                treated as UTC.
            unit: Epoch unit for numeric values (e.g. ``'ms'``). Without it,
                pandas reads integers as nanoseconds.

        Returns:
            The formatted string, or None when ``value`` is missing.
        """
        if value is None:
            return None
        timestamp = pd.to_datetime(value, unit=unit, utc=True)
        if pd.isna(timestamp):
            return None
        return timestamp.tz_convert(self.timezone).strftime(self.TIME_FORMAT)

    def _format_run(self, run: Dict) -> Dict:
        """Build the model summary for one record of ``mlflow.search_runs(...).to_dict('records')``."""
        params = {}
        metrics = {}
        for key, value in run.items():
            # Slice off only the leading prefix; str.replace would also strip
            # later occurrences inside the name.
            if key.startswith('params.'):
                params[key[len('params.'):]] = self._clean_value(value)
            elif key.startswith('metrics.'):
                metrics[key[len('metrics.'):]] = self._clean_value(value)

        run_id = run['run_id']
        return {
            'run_id': run_id,
            'experiment_id': run['experiment_id'],
            # Descriptive params logged by the training job. Runs that did not
            # log them yield None instead of failing the whole page.
            'algorithm': params.get('algorithm'),
            'task_type': params.get('task_type'),
            'dataset': params.get('dataset'),
            'training_start_time': self._to_local_time_str(run.get('start_time')),
            'training_end_time': self._to_local_time_str(run.get('end_time')),
            'metrics': metrics,
            'parameters': {
                k: v for k, v in params.items()
                if k not in self.ALGORITHM_INFO_KEYS
            },
            'algorithm_info': {
                key: params.get(key) or '' for key in self.ALGORITHM_INFO_KEYS
            },
            'mlflow_run_id': run_id,
            'model_path': f"models/{run_id}",
        }

    def get_finished_models(
        self,
        page: int = 1,
        page_size: int = 10,
        experiment_name: Optional[str] = None
    ) -> Dict:
        """Return one page of runs whose status is FINISHED, formatted as model summaries.

        Args:
            page: 1-based page number.
            page_size: Number of models per page.
            experiment_name: If given, only runs of this experiment are listed.
                Otherwise runs of all active experiments are listed, grouped
                by experiment in the order MLflow returns them.

        Returns:
            ``{'status': 'success', 'models', 'total_count', 'page', 'page_size'}``
            or ``{'status': 'error', 'message'}``.
        """
        try:
            paging_error = self._paging_error(page, page_size)
            if paging_error is not None:
                return paging_error

            if experiment_name:
                experiment = mlflow.get_experiment_by_name(experiment_name)
                if experiment is None:
                    return {
                        'status': 'error',
                        'message': f'Experiment {experiment_name} not found'
                    }
                experiments = [experiment]
            else:
                experiments = mlflow.search_experiments(view_type=ViewType.ACTIVE_ONLY)

            all_runs = []
            for exp in experiments:
                runs = mlflow.search_runs(
                    experiment_ids=[exp.experiment_id],
                    filter_string="status = 'FINISHED'"
                )
                all_runs.extend(runs.to_dict('records'))

            models = [
                self._format_run(run)
                for run in self._paginate(all_runs, page, page_size)
            ]
            return {
                'status': 'success',
                'models': models,
                'total_count': len(all_runs),
                'page': page,
                'page_size': page_size
            }
        except Exception as e:
            error_msg = f"Error getting finished models: {str(e)}"
            self.logger.exception(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }

    def get_experiments(
        self,
        page: int = 1,
        page_size: int = 10,
        include_deleted: bool = False
    ) -> Dict:
        """Return one page of MLflow experiments with finished-run statistics.

        Args:
            page: 1-based page number.
            page_size: Number of experiments per page.
            include_deleted: Also list experiments in the deleted lifecycle stage.

        Returns:
            ``{'status': 'success', 'experiments', 'total_count', 'page', 'page_size'}``
            or ``{'status': 'error', 'message'}``. Each experiment's
            ``last_update_time`` is the latest end time of its FINISHED runs,
            falling back to its creation time.
        """
        try:
            paging_error = self._paging_error(page, page_size)
            if paging_error is not None:
                return paging_error

            # The view type must be requested explicitly; filtering the result
            # of an active-only search can never yield deleted experiments.
            view_type = ViewType.ALL if include_deleted else ViewType.ACTIVE_ONLY
            experiments = mlflow.search_experiments(view_type=view_type)

            experiment_list = []
            for exp in self._paginate(experiments, page, page_size):
                runs = mlflow.search_runs(
                    experiment_ids=[exp.experiment_id],
                    filter_string="status = 'FINISHED'"
                )
                # Experiment.creation_time is an epoch in milliseconds (or None).
                creation_time = self._to_local_time_str(exp.creation_time, unit='ms')
                last_update_time = None
                if len(runs) > 0:
                    last_update_time = self._to_local_time_str(runs['end_time'].max())
                if last_update_time is None:
                    last_update_time = creation_time

                experiment_list.append({
                    'experiment_id': exp.experiment_id,
                    'name': exp.name,
                    'artifact_location': exp.artifact_location,
                    'lifecycle_stage': exp.lifecycle_stage,
                    'creation_time': creation_time,
                    'last_update_time': last_update_time,
                    'tags': exp.tags,
                    'runs_count': len(runs)
                })

            return {
                'status': 'success',
                'experiments': experiment_list,
                'total_count': len(experiments),
                'page': page,
                'page_size': page_size
            }
        except Exception as e:
            error_msg = f"Error getting experiments: {str(e)}"
            self.logger.exception(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }