diff --git a/api/model_api.py b/api/model_api.py index f3f4f7a..6e0b801 100644 --- a/api/model_api.py +++ b/api/model_api.py @@ -8,10 +8,14 @@ model_manager = ModelManager() # 数据模型 class TrainRequest(BaseModel): - model: str - dataset: Dict[str, str] + + train_path: str + val_path: str + algorithm: str + task_type : str parameters: Dict metrics: List[str] + experiment_name: str class PredictRequest(BaseModel): run_id: str @@ -50,13 +54,15 @@ async def get_metrics(): async def train_model(request: TrainRequest): """模型训练""" result = model_manager.train_model( - train_data=request.dataset['train'], - val_data=request.dataset['val'], + train_path=request.train_path, + val_path=request.val_path, model_config={ - 'model_name': request.model, - 'parameters': request.parameters, + 'algorithm': request.algorithm, + 'task_type': request.task_type, + 'params': request.parameters, 'metrics': request.metrics - } + }, + experiment_name=request.experiment_name ) if result['status'] == 'error': raise HTTPException(status_code=500, detail=result['message']) diff --git a/example_model_manager.py b/example_model_manager.py index 1b1a071..cabf46a 100644 --- a/example_model_manager.py +++ b/example_model_manager.py @@ -77,7 +77,7 @@ print("--------------------------------------------模型训练----------------- model_config = { 'algorithm': 'XGBClassifier', 'task_type': 'classification', - 'dataset' : '/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629', + # 'dataset' : '/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629', 'params': { 'n_estimators': 100, 'learning_rate': 0.1, diff --git a/function/__pycache__/model_manager.cpython-39.pyc b/function/__pycache__/model_manager.cpython-39.pyc index de479fb..73647ac 100644 Binary files a/function/__pycache__/model_manager.cpython-39.pyc and b/function/__pycache__/model_manager.cpython-39.pyc differ diff --git a/function/model_manager.py b/function/model_manager.py index 22873e6..b062962 100644 --- a/function/model_manager.py +++ b/function/model_manager.py @@ -833,7 +833,9 @@ class ModelManager: mlflow.log_param('algorithm', model_config['algorithm']) mlflow.log_param('task_type', model_config['task_type']) # mlflow.log_param('dataset', experiment_name.split('_')[0]) # 从实验名称提取数据集名称 - mlflow.log_param('dataset', model_config['dataset']) # 直接写数据集路径 + mlflow.log_param('dataset_train', train_path) # 直接写数据集路径 + mlflow.log_param('dataset_val', val_path) + # timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') # mlflow.log_param('start_time', timestamp)