134 lines
4.8 KiB
Python
134 lines
4.8 KiB
Python
import mlflow
|
|
from mlflow.tracking import MlflowClient
|
|
import pandas as pd
|
|
from typing import Dict, List, Optional
|
|
import logging
|
|
from pathlib import Path
|
|
import datetime
|
|
import yaml
|
|
import json
|
|
|
|
class ModelManager:
    """Model manager that lists finished training runs stored in MLflow."""

    # Name tag put on the file handler so repeated instantiation can detect
    # that the shared module logger has already been configured.
    _FILE_HANDLER_NAME = 'model_manager_file'

    def __init__(self, config: Optional[Dict] = None):
        """Initialize the model manager.

        Args:
            config: Optional settings. The only key read here is
                ``'mlflow_uri'`` (tracking server URI); it defaults to
                ``'http://10.0.0.202:5000'``.
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)
        self._setup_logging()

        # Initialize the MLflow client. The global tracking URI is still set
        # because get_finished_models uses the fluent ``mlflow.*`` API; the
        # client is additionally bound to the same URI explicitly so it does
        # not depend on global state changed elsewhere.
        self.mlflow_uri = self.config.get('mlflow_uri', 'http://10.0.0.202:5000')
        mlflow.set_tracking_uri(self.mlflow_uri)
        self.client = MlflowClient(tracking_uri=self.mlflow_uri)

    def _setup_logging(self):
        """Attach a timestamped file handler under ``.log/`` to the logger.

        ``logging.getLogger(__name__)`` returns the same logger object for
        every instance, so the handler is attached only once; otherwise each
        new ModelManager would add another handler and every message would be
        written multiple times.
        """
        self.logger.setLevel(logging.INFO)

        already_configured = any(
            handler.get_name() == self._FILE_HANDLER_NAME
            for handler in self.logger.handlers
        )
        if already_configured:
            return

        log_dir = Path('.log')
        log_dir.mkdir(exist_ok=True)

        file_handler = logging.FileHandler(
            log_dir / f'model_manager_{datetime.datetime.now():%Y%m%d_%H%M%S}.log',
            encoding='utf-8',
        )
        file_handler.set_name(self._FILE_HANDLER_NAME)
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        self.logger.addHandler(file_handler)

    @staticmethod
    def _build_model_info(run: Dict) -> Dict:
        """Convert one run record into the model-info dict returned to callers.

        Args:
            run: One row of the ``mlflow.search_runs`` DataFrame as a dict.
                Reads ``'run_id'``, ``'start_time'``, ``'end_time'`` and every
                key prefixed with ``'params.'`` or ``'metrics.'``.

        Returns:
            Dict describing the model (id, algorithm, task type, dataset,
            timing, metrics, parameters, algorithm info, run id, model path).
        """
        param_prefix = 'params.'
        metric_prefix = 'metrics.'
        info_keys = ('principle', 'advantages', 'disadvantages')

        # Strip only the leading prefix: str.replace would also remove later
        # occurrences inside the name (e.g. 'params.model.params.lr').
        params = {}
        metrics = {}
        for key, value in run.items():
            if key.startswith(param_prefix):
                params[key[len(param_prefix):]] = value
            elif key.startswith(metric_prefix):
                metrics[key[len(metric_prefix):]] = value

        started = pd.to_datetime(run['start_time'])
        run_id = run['run_id']

        return {
            'model_id': f"model_{started.strftime('%Y%m%d_%H%M%S')}",
            # A parameter column exists only if some run in the experiment
            # logged it, so use .get(): a missing parameter must not abort
            # the whole listing with a KeyError.
            'algorithm': params.get('algorithm'),
            'task_type': params.get('task_type'),
            'dataset': params.get('dataset'),
            'training_start_time': run['start_time'],
            'training_end_time': run['end_time'],
            'metrics': metrics,
            'parameters': {
                name: value for name, value in params.items()
                if name not in info_keys
            },
            'algorithm_info': {name: params.get(name, '') for name in info_keys},
            'mlflow_run_id': run_id,
            'model_path': f"models/{run_id}",
        }

    def get_finished_models(
        self,
        page: int = 1,
        page_size: int = 10,
        experiment_name: Optional[str] = None
    ) -> Dict:
        """Return one page of models whose training run has finished.

        Args:
            page: 1-based page number.
            page_size: Number of models per page.
            experiment_name: If given, only runs of this experiment are
                listed; otherwise all experiments are searched.

        Returns:
            On success ``{'status': 'success', 'models': [...],
            'total_count': int, 'page': int, 'page_size': int}``.
            On failure ``{'status': 'error', 'message': str}``; this method
            does not propagate exceptions.
        """
        try:
            # Reject non-positive paging values up front: they would otherwise
            # produce a negative slice start and silently return wrong rows.
            if page < 1 or page_size < 1:
                return {
                    'status': 'error',
                    'message': (
                        f'Invalid pagination: page={page}, '
                        f'page_size={page_size} (both must be >= 1)'
                    )
                }

            # Resolve which experiments to search.
            if experiment_name:
                experiment = mlflow.get_experiment_by_name(experiment_name)
                if experiment is None:
                    return {
                        'status': 'error',
                        'message': f'Experiment {experiment_name} not found'
                    }
                experiments = [experiment]
            else:
                experiments = mlflow.search_experiments()

            # Collect every finished run, one experiment at a time.
            all_runs = []
            for exp in experiments:
                runs = mlflow.search_runs(
                    experiment_ids=[exp.experiment_id],
                    filter_string="status = 'FINISHED'"
                )
                all_runs.extend(runs.to_dict('records'))

            # Pagination is applied in memory over the collected runs.
            total_count = len(all_runs)
            start_idx = (page - 1) * page_size
            page_runs = all_runs[start_idx:start_idx + page_size]

            models = [self._build_model_info(run) for run in page_runs]

            return {
                'status': 'success',
                'models': models,
                'total_count': total_count,
                'page': page,
                'page_size': page_size
            }

        except Exception as e:
            error_msg = f"Error getting finished models: {str(e)}"
            # logger.exception keeps the traceback in the log file.
            self.logger.exception(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }