diff --git a/api/model_api.py b/api/model_api.py index 7e3837c..f3f4f7a 100644 --- a/api/model_api.py +++ b/api/model_api.py @@ -51,7 +51,7 @@ async def train_model(request: TrainRequest): """模型训练""" result = model_manager.train_model( train_data=request.dataset['train'], - val_data=request.dataset.get('val'), + val_data=request.dataset['val'], model_config={ 'model_name': request.model, 'parameters': request.parameters, diff --git a/doc/安装文档.md b/doc/安装文档.md index 6d66e13..3cd1652 100644 --- a/doc/安装文档.md +++ b/doc/安装文档.md @@ -26,6 +26,28 @@ GET http://10.0.0.202:8992/data/feature/method/{method_name} ### 3.5 处理数据集 POST http://10.0.0.202:8992/data/process +传递参数 + { + "input_path": "dataset/dataset_raw/breast_cancer.csv", + "output_dir": "dataset/dataset_processed", + "process_methods": [ + { + "method_name": "IsolationForest", + "params": { + "contamination": 0.1, + "random_state": 42 + } + } + ], + "feature_methods": [ + + ], + "split_params": { + "test_size": 0.1, + "val_size": 0.2 + } + } + ### 3.6 获取可用数据集列表 GET http://10.0.0.202:8992/data/datasets diff --git a/example_model_manager.py b/example_model_manager.py index afdaba2..1b1a071 100644 --- a/example_model_manager.py +++ b/example_model_manager.py @@ -71,15 +71,7 @@ print("--------------------------------------------获取所有已训练模型 e print("--------------------------------------------模型训练---------------------------------------------------") -# 加载数据 -train_data = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629/train_breast_cancer_20250219_144629.csv') -val_data = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629/val_breast_cancer_20250219_144629.csv') -# 准备特征和标签 -X_train = train_data.drop('target', axis=1) -y_train = train_data['target'] -X_val = val_data.drop('target', axis=1) -y_val = val_data['target'] # 模型配置 model_config = { @@ -98,14 +90,8 @@ model_config = { # 模型文件 直接在 mlruns/文件夹下 for i in range(3, 4): result = manager.train_model( - { - 'features': X_train, - 'labels': y_train - }, - { - 'features': X_val, - 'labels': y_val - }, + '/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250224_170615/train_breast_cancer_20250224_170615.csv', + '/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250224_170615/val_breast_cancer_20250224_170615.csv', model_config, f'breast_cancer_classification_{i}' ) diff --git a/function/__pycache__/model_manager.cpython-39.pyc b/function/__pycache__/model_manager.cpython-39.pyc index c96ae0d..de479fb 100644 Binary files a/function/__pycache__/model_manager.cpython-39.pyc and b/function/__pycache__/model_manager.cpython-39.pyc differ diff --git a/function/model_manager.py b/function/model_manager.py index 0b4fa58..22873e6 100644 --- a/function/model_manager.py +++ b/function/model_manager.py @@ -771,8 +771,8 @@ class ModelManager: } def train_model( self, - train_data: Dict, - val_data: Dict, + train_path: str, + val_path: str, model_config: Dict, experiment_name: str ) -> Dict: @@ -789,6 +789,11 @@ class ModelManager: 训练结果字典 """ try: + + + + + # 检查实验是否存在且被删除 experiment = mlflow.get_experiment_by_name(experiment_name) if experiment and experiment.lifecycle_stage == 'deleted': @@ -800,6 +805,28 @@ class ModelManager: # 设置MLflow实验 mlflow.set_experiment(experiment_name) + + if os.path.exists(train_path): + # 加载数据 + train_data = pd.read_csv(train_path) + else: + return { + 'status': 'error', + 'message': '找不到训练集路径' + } + if os.path.exists(val_path): + val_data = pd.read_csv(val_path) + else: + return{ + 'status': 'error', + 'message': '找不到验证集路径' + } + + # 准备特征和标签 + X_train = train_data.drop('target', axis=1) + y_train = train_data['target'] + X_val = val_data.drop('target', axis=1) + y_val = val_data['target'] with mlflow.start_run() as run: # 记录基本信息 @@ -833,15 +860,15 @@ class ModelManager: # 训练模型 self.logger.info(f"Starting training {model_config['algorithm']}") - model.fit(train_data['features'], train_data['labels']) + model.fit(X_train, y_train) # timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') # mlflow.log_param('end_time', timestamp) # 在验证集上评估 - val_predictions = model.predict(val_data['features']) + val_predictions = model.predict(X_val) metrics = self._calculate_metrics( - val_data['labels'], + y_val, val_predictions, model_config['task_type'] )