import mlflow
from mlflow.tracking import MlflowClient
import pandas as pd
from typing import Dict, List, Optional
import logging
from pathlib import Path
import datetime
import yaml
import json


class ModelManager:
    """Model manager that lists trained models recorded in MLflow."""

    # Params describing the algorithm itself rather than its configuration;
    # they are reported under 'algorithm_info' instead of 'parameters'.
    _ALGORITHM_INFO_KEYS = ('principle', 'advantages', 'disadvantages')

    def __init__(self, config: Optional[Dict] = None):
        """Initialize the model manager.

        Args:
            config: Optional settings dict. The key 'mlflow_uri' selects the
                MLflow tracking server (default 'http://10.0.0.202:5000').
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)
        self._setup_logging()

        # Initialize the MLflow client; set_tracking_uri is called first so
        # both the fluent mlflow.* API and MlflowClient use the same server.
        self.mlflow_uri = self.config.get('mlflow_uri', 'http://10.0.0.202:5000')
        mlflow.set_tracking_uri(self.mlflow_uri)
        self.client = MlflowClient()

    def _setup_logging(self):
        """Attach a timestamped file handler under '.log/' to the logger.

        logging.getLogger(__name__) returns the same logger object for every
        ModelManager instance, so the file handler is attached only once.
        Attaching one per instance would duplicate every log line and leave
        an extra log file open for each instance created.
        """
        self.logger.setLevel(logging.INFO)
        if any(isinstance(h, logging.FileHandler) for h in self.logger.handlers):
            return

        log_dir = Path('.log')
        log_dir.mkdir(exist_ok=True)

        file_handler = logging.FileHandler(
            log_dir / f'model_manager_{datetime.datetime.now():%Y%m%d_%H%M%S}.log'
        )
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        self.logger.addHandler(file_handler)

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True for the None/NaN placeholders a DataFrame uses for absent values."""
        return value is None or (pd.api.types.is_scalar(value) and bool(pd.isna(value)))

    def _format_run(self, run: Dict) -> Dict:
        """Convert one record from mlflow.search_runs(...).to_dict('records') into model info."""
        params = {}
        metrics = {}
        for key, value in run.items():
            if key.startswith('params.'):
                # The records come from a per-experiment DataFrame, so params
                # logged only by other runs appear here as None/NaN; skip them.
                if not self._is_missing(value):
                    params[key[len('params.'):]] = value
            elif key.startswith('metrics.'):
                metrics[key[len('metrics.'):]] = value

        # NOTE(review): model_id has one-second resolution, so runs started in
        # the same second share an id; mlflow_run_id is the unique key.
        start_time = pd.to_datetime(run['start_time'])
        return {
            'model_id': f"model_{start_time.strftime('%Y%m%d_%H%M%S')}",
            # None when the run did not log the param (previously a KeyError
            # that aborted the whole listing).
            'algorithm': params.get('algorithm'),
            'task_type': params.get('task_type'),
            'dataset': params.get('dataset'),
            'training_start_time': run['start_time'],
            'training_end_time': run['end_time'],
            'metrics': metrics,
            'parameters': {
                k: v for k, v in params.items()
                if k not in self._ALGORITHM_INFO_KEYS
            },
            'algorithm_info': {
                key: params.get(key, '') for key in self._ALGORITHM_INFO_KEYS
            },
            'mlflow_run_id': run['run_id'],
            'model_path': f"models/{run['run_id']}"
        }

    def get_finished_models(
        self,
        page: int = 1,
        page_size: int = 10,
        experiment_name: Optional[str] = None
    ) -> Dict:
        """
        Get a page of runs whose MLflow status is FINISHED.

        Args:
            page: 1-based page number (must be >= 1)
            page_size: number of models per page (must be >= 1)
            experiment_name: if given, only runs of this experiment are listed

        Returns:
            On success: {'status': 'success', 'models': [...], 'total_count',
            'page', 'page_size'}. On failure (unknown experiment, invalid
            pagination, or any exception): {'status': 'error', 'message': ...}.
            This method does not raise.
        """
        try:
            if page < 1 or page_size < 1:
                # Negative slice indices would otherwise return runs from
                # the end of the list instead of an empty/invalid page.
                return {
                    'status': 'error',
                    'message': f'Invalid pagination: page ({page}) and page_size ({page_size}) must be >= 1'
                }

            # Resolve which experiments to search.
            if experiment_name:
                experiment = mlflow.get_experiment_by_name(experiment_name)
                if experiment is None:
                    return {
                        'status': 'error',
                        'message': f'Experiment {experiment_name} not found'
                    }
                experiments = [experiment]
            else:
                experiments = mlflow.search_experiments()

            # Collect finished runs from every selected experiment.
            all_runs = []
            for exp in experiments:
                runs = mlflow.search_runs(
                    experiment_ids=[exp.experiment_id],
                    filter_string="status = 'FINISHED'"
                )
                all_runs.extend(runs.to_dict('records'))

            # Paginate in memory.
            total_count = len(all_runs)
            start_idx = (page - 1) * page_size
            page_runs = all_runs[start_idx:start_idx + page_size]

            return {
                'status': 'success',
                'models': [self._format_run(run) for run in page_runs],
                'total_count': total_count,
                'page': page,
                'page_size': page_size
            }

        except Exception as e:
            error_msg = f"Error getting finished models: {str(e)}"
            # logger.exception records the traceback along with the message.
            self.logger.exception(error_msg)
            return {
                'status': 'error',
                'message': error_msg
            }