From d867a3004d35cb05d5f194574e54b63c679445cc Mon Sep 17 00:00:00 2001 From: haotian <2421912570@qq.com> Date: Wed, 19 Feb 2025 17:38:20 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90--=E5=AE=8C=E6=88=90=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=E6=A8=A1=E5=9E=8B=E5=88=97=E8=A1=A8=E6=96=B9=E6=B3=95?= =?UTF-8?q?--=E6=97=B6=E9=97=B4=E4=BB=8D=E7=84=B6=E6=98=AF=E5=88=9D?= =?UTF-8?q?=E5=A7=8B=E6=97=B6=E5=8C=BA=E7=9A=84=E6=97=B6=E9=97=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- example_experiment_list.py | 34 +++++++++++++++ function/model_manager.py | 85 +++++++++++++++++++++++++++++++++++++- 2 files changed, 117 insertions(+), 2 deletions(-) create mode 100644 example_experiment_list.py diff --git a/example_experiment_list.py b/example_experiment_list.py new file mode 100644 index 0000000..14eaf7d --- /dev/null +++ b/example_experiment_list.py @@ -0,0 +1,34 @@ +from function.model_manager import ModelManager + +# 创建模型管理器实例 +manager = ModelManager() + +# 获取实验列表 +result = manager.get_experiments( + page=2, + page_size=10, + include_deleted=False +) + +# 打印结果 +print("\nMLFlow实验列表:") +print(f"状态: {result['status']}") +if result['status'] == 'success': + print(f"\n总数: {result['total_count']}") + print(f"当前页: {result['page']}") + print(f"每页数量: {result['page_size']}") + print("\n实验列表:") + for exp in result['experiments']: + print(f"\n实验ID: {exp['experiment_id']}") + print(f"名称: {exp['name']}") + print(f"存储位置: {exp['artifact_location']}") + print(f"状态: {exp['lifecycle_stage']}") + print(f"创建时间: {exp['creation_time']}") + print(f"最后更新: {exp['last_update_time']}") + print(f"运行次数: {exp['runs_count']}") + if exp['tags']: + print("标签:") + for tag_name, tag_value in exp['tags'].items(): + print(f" {tag_name}: {tag_value}") +else: + print(f"错误信息: {result['message']}") \ No newline at end of file diff --git a/function/model_manager.py b/function/model_manager.py index c1ec09d..d9e1afd 100644 --- a/function/model_manager.py +++ b/function/model_manager.py @@ -94,6 +94,10 @@ class ModelManager: elif key.startswith('metrics.'): metrics[key.replace('metrics.', '')] = value + # 转换时间为本地时间 + start_time = pd.to_datetime(run['start_time']).tz_convert('Asia/Shanghai') + end_time = pd.to_datetime(run['end_time']).tz_convert('Asia/Shanghai') + # 构建模型信息 model_info = { 'run_id': run['run_id'], @@ -101,8 +105,8 @@ class ModelManager: 'algorithm': params['algorithm'], # 从配置或其他地方获取 'task_type': params['task_type'], # 从配置或其他地方获取 'dataset': params['dataset'], # 从配置或其他地方获取 - 'training_start_time': run['start_time'], - 'training_end_time': run['end_time'], + 'training_start_time': start_time.strftime('%Y-%m-%d %H:%M:%S'), # 格式化为本地时间字符串 + 'training_end_time': end_time.strftime('%Y-%m-%d %H:%M:%S'), 'metrics': metrics, 'parameters': { k: v for k, v in params.items() @@ -129,6 +133,83 @@ class ModelManager: except Exception as e: error_msg = f"Error getting finished models: {str(e)}" self.logger.error(error_msg) + return { + 'status': 'error', + 'message': error_msg + } + + def get_experiments( + self, + page: int = 1, + page_size: int = 10, + include_deleted: bool = False + ) -> Dict: + """ + 获取MLflow中保存的实验列表 + + Args: + page: 页码 + page_size: 每页数量 + include_deleted: 是否包含已删除的实验 + + Returns: + 实验列表信息 + """ + try: + # 获取所有实验 + experiments = mlflow.search_experiments() + + # 过滤已删除的实验 + if not include_deleted: + experiments = [exp for exp in experiments if exp.lifecycle_stage == 'active'] + + # 计算分页 + total_count = len(experiments) + start_idx = (page - 1) * page_size + end_idx = start_idx + page_size + page_experiments = experiments[start_idx:end_idx] + + # 格式化实验信息 + experiment_list = [] + for exp in page_experiments: + # 获取实验的运行次数 + runs = mlflow.search_runs( + experiment_ids=[exp.experiment_id], + filter_string="status = 'FINISHED'" + ) + + # 获取最后更新时间并转换为本地时间 + if len(runs) > 0: + last_update_time = pd.to_datetime(runs['end_time'].max()) + else: + last_update_time = pd.to_datetime(exp.creation_time) + + # 创建时间转换为本地时间 + creation_time = pd.to_datetime(exp.creation_time) + + experiment_info = { + 'experiment_id': exp.experiment_id, + 'name': exp.name, + 'artifact_location': exp.artifact_location, + 'lifecycle_stage': exp.lifecycle_stage, + 'creation_time': creation_time.strftime('%Y-%m-%d %H:%M:%S'), + 'last_update_time': last_update_time.strftime('%Y-%m-%d %H:%M:%S'), + 'tags': exp.tags, + 'runs_count': len(runs) + } + experiment_list.append(experiment_info) + + return { + 'status': 'success', + 'experiments': experiment_list, + 'total_count': total_count, + 'page': page, + 'page_size': page_size + } + + except Exception as e: + error_msg = f"Error getting experiments: {str(e)}" + self.logger.error(error_msg) return { 'status': 'error', 'message': error_msg