完成--完成获取模型列表方法--时间仍然是初始时区的时间
This commit is contained in:
parent
04ce95b59f
commit
d867a3004d
34
example_experiment_list.py
Normal file
34
example_experiment_list.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from function.model_manager import ModelManager
|
||||||
|
|
||||||
|
# 创建模型管理器实例
|
||||||
|
manager = ModelManager()
|
||||||
|
|
||||||
|
# 获取实验列表
|
||||||
|
result = manager.get_experiments(
|
||||||
|
page=2,
|
||||||
|
page_size=10,
|
||||||
|
include_deleted=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 打印结果
|
||||||
|
print("\nMLFlow实验列表:")
|
||||||
|
print(f"状态: {result['status']}")
|
||||||
|
if result['status'] == 'success':
|
||||||
|
print(f"\n总数: {result['total_count']}")
|
||||||
|
print(f"当前页: {result['page']}")
|
||||||
|
print(f"每页数量: {result['page_size']}")
|
||||||
|
print("\n实验列表:")
|
||||||
|
for exp in result['experiments']:
|
||||||
|
print(f"\n实验ID: {exp['experiment_id']}")
|
||||||
|
print(f"名称: {exp['name']}")
|
||||||
|
print(f"存储位置: {exp['artifact_location']}")
|
||||||
|
print(f"状态: {exp['lifecycle_stage']}")
|
||||||
|
print(f"创建时间: {exp['creation_time']}")
|
||||||
|
print(f"最后更新: {exp['last_update_time']}")
|
||||||
|
print(f"运行次数: {exp['runs_count']}")
|
||||||
|
if exp['tags']:
|
||||||
|
print("标签:")
|
||||||
|
for tag_name, tag_value in exp['tags'].items():
|
||||||
|
print(f" {tag_name}: {tag_value}")
|
||||||
|
else:
|
||||||
|
print(f"错误信息: {result['message']}")
|
||||||
@ -94,6 +94,10 @@ class ModelManager:
|
|||||||
elif key.startswith('metrics.'):
|
elif key.startswith('metrics.'):
|
||||||
metrics[key.replace('metrics.', '')] = value
|
metrics[key.replace('metrics.', '')] = value
|
||||||
|
|
||||||
|
# 转换时间为本地时间
|
||||||
|
start_time = pd.to_datetime(run['start_time']).tz_convert('Asia/Shanghai')
|
||||||
|
end_time = pd.to_datetime(run['end_time']).tz_convert('Asia/Shanghai')
|
||||||
|
|
||||||
# 构建模型信息
|
# 构建模型信息
|
||||||
model_info = {
|
model_info = {
|
||||||
'run_id': run['run_id'],
|
'run_id': run['run_id'],
|
||||||
@ -101,8 +105,8 @@ class ModelManager:
|
|||||||
'algorithm': params['algorithm'], # 从配置或其他地方获取
|
'algorithm': params['algorithm'], # 从配置或其他地方获取
|
||||||
'task_type': params['task_type'], # 从配置或其他地方获取
|
'task_type': params['task_type'], # 从配置或其他地方获取
|
||||||
'dataset': params['dataset'], # 从配置或其他地方获取
|
'dataset': params['dataset'], # 从配置或其他地方获取
|
||||||
'training_start_time': run['start_time'],
|
'training_start_time': start_time.strftime('%Y-%m-%d %H:%M:%S'), # 格式化为本地时间字符串
|
||||||
'training_end_time': run['end_time'],
|
'training_end_time': end_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||||
'metrics': metrics,
|
'metrics': metrics,
|
||||||
'parameters': {
|
'parameters': {
|
||||||
k: v for k, v in params.items()
|
k: v for k, v in params.items()
|
||||||
@ -129,6 +133,83 @@ class ModelManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error getting finished models: {str(e)}"
|
error_msg = f"Error getting finished models: {str(e)}"
|
||||||
self.logger.error(error_msg)
|
self.logger.error(error_msg)
|
||||||
|
return {
|
||||||
|
'status': 'error',
|
||||||
|
'message': error_msg
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_experiments(
|
||||||
|
self,
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 10,
|
||||||
|
include_deleted: bool = False
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
获取MLflow中保存的实验列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
page: 页码
|
||||||
|
page_size: 每页数量
|
||||||
|
include_deleted: 是否包含已删除的实验
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
实验列表信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取所有实验
|
||||||
|
experiments = mlflow.search_experiments()
|
||||||
|
|
||||||
|
# 过滤已删除的实验
|
||||||
|
if not include_deleted:
|
||||||
|
experiments = [exp for exp in experiments if exp.lifecycle_stage == 'active']
|
||||||
|
|
||||||
|
# 计算分页
|
||||||
|
total_count = len(experiments)
|
||||||
|
start_idx = (page - 1) * page_size
|
||||||
|
end_idx = start_idx + page_size
|
||||||
|
page_experiments = experiments[start_idx:end_idx]
|
||||||
|
|
||||||
|
# 格式化实验信息
|
||||||
|
experiment_list = []
|
||||||
|
for exp in page_experiments:
|
||||||
|
# 获取实验的运行次数
|
||||||
|
runs = mlflow.search_runs(
|
||||||
|
experiment_ids=[exp.experiment_id],
|
||||||
|
filter_string="status = 'FINISHED'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取最后更新时间并转换为本地时间
|
||||||
|
if len(runs) > 0:
|
||||||
|
last_update_time = pd.to_datetime(runs['end_time'].max())
|
||||||
|
else:
|
||||||
|
last_update_time = pd.to_datetime(exp.creation_time)
|
||||||
|
|
||||||
|
# 创建时间转换为本地时间
|
||||||
|
creation_time = pd.to_datetime(exp.creation_time)
|
||||||
|
|
||||||
|
experiment_info = {
|
||||||
|
'experiment_id': exp.experiment_id,
|
||||||
|
'name': exp.name,
|
||||||
|
'artifact_location': exp.artifact_location,
|
||||||
|
'lifecycle_stage': exp.lifecycle_stage,
|
||||||
|
'creation_time': creation_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||||
|
'last_update_time': last_update_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||||
|
'tags': exp.tags,
|
||||||
|
'runs_count': len(runs)
|
||||||
|
}
|
||||||
|
experiment_list.append(experiment_info)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'status': 'success',
|
||||||
|
'experiments': experiment_list,
|
||||||
|
'total_count': total_count,
|
||||||
|
'page': page,
|
||||||
|
'page_size': page_size
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error getting experiments: {str(e)}"
|
||||||
|
self.logger.error(error_msg)
|
||||||
return {
|
return {
|
||||||
'status': 'error',
|
'status': 'error',
|
||||||
'message': error_msg
|
'message': error_msg
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user