From 382271e424c6a74f41aa5c92743ec7c8eb3882af Mon Sep 17 00:00:00 2001 From: haotian <2421912570@qq.com> Date: Tue, 25 Feb 2025 09:49:27 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9--=E4=BF=AE=E6=94=B9=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E8=AE=AD=E7=BB=83=E6=8E=A5=E5=8F=A3=E7=9A=84=E5=8F=82?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/model_api.py | 20 ++++++++++++------ example_model_manager.py | 2 +- .../__pycache__/model_manager.cpython-39.pyc | Bin 19581 -> 19614 bytes function/model_manager.py | 4 +++- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/api/model_api.py b/api/model_api.py index f3f4f7a..6e0b801 100644 --- a/api/model_api.py +++ b/api/model_api.py @@ -8,10 +8,14 @@ model_manager = ModelManager() # 数据模型 class TrainRequest(BaseModel): - model: str - dataset: Dict[str, str] + + train_path: str + val_path: str + algorithm: str + task_type : str parameters: Dict metrics: List[str] + experiment_name: str class PredictRequest(BaseModel): run_id: str @@ -50,13 +54,15 @@ async def get_metrics(): async def train_model(request: TrainRequest): """模型训练""" result = model_manager.train_model( - train_data=request.dataset['train'], - val_data=request.dataset['val'], + train_path=request.train_path, + val_path=request.val_path, model_config={ - 'model_name': request.model, - 'parameters': request.parameters, + 'algorithm': request.algorithm, + 'task_type': request.task_type, + 'params': request.parameters, 'metrics': request.metrics - } + }, + experiment_name=request.experiment_name ) if result['status'] == 'error': raise HTTPException(status_code=500, detail=result['message']) diff --git a/example_model_manager.py b/example_model_manager.py index 1b1a071..cabf46a 100644 --- a/example_model_manager.py +++ b/example_model_manager.py @@ -77,7 +77,7 @@ print("--------------------------------------------模型训练----------------- model_config = { 'algorithm': 'XGBClassifier', 'task_type': 'classification', - 'dataset' : '/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629', + # 'dataset' : '/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629', 'params': { 'n_estimators': 100, 'learning_rate': 0.1, diff --git a/function/__pycache__/model_manager.cpython-39.pyc b/function/__pycache__/model_manager.cpython-39.pyc index de479fb1c966cc81ccb5df58a2c3af07d9ffa4ab..73647ac7a611f01d90b2e3d4e35970a7290c4ee2 100644 GIT binary patch delta 429 zcmex6gK^$WM&3kTUM>b8h*a2{{%a#IyFJGqCI*HHOjTKvW$Z69HcftKuUXHykcp9@ zgkga|3V#g~h+V^+A|T1IKyV?03q!0}3{x$CtpG^AhQCGtBtL<%P^N}qfe=s~k0irF z##%wJZowMC6hR;#W|~lmFi=EDlA#%4gD_Z37$ycX!A_DPMI@VP0%OsX6j6xPDPlGZ zKnq0R76{b{r7>he9kztAkFi#`h9Qe>0eg!0hGu3)hH!>p218)b6emp1a}?*^ z#>h}B!N^c$kRm&If}^^)OszznL@-c2h&Es_0J4G^H03s*a};EfiQ-L3EJ-X*Er~BF yO3ciQ;)Zd{5_2{`an504te70=X3oOH$iu$5%k2ddW83Cp&*O}YtedU9ofrXq3U0Rm delta 404 zcmXw#y-UMD7{=fC5?d*z%|~l2?Ir$Rpt7MT84xWu`Lg}`q@-)3&jesV z+)ckydq>y*u2mI@Who2^0+(3MpRP}J^Be+Qs4?NAr#TR)OywF=lXJ1w978x7g&szC zV^5Dg{ohLn%-AIZ?3}TD#HYXtb0GMUO87e0w_qoJox&aVG2{O6ujw6Fq&vy5V1RBZ6llxE8C3BFr8}T#rUM{QE|+SoX{lvn7bpY=I?KEvQa}pAC(Sbo%0t gLup};;E{jOoF!^1d}w|VJPpqm-w>s6eR&@L00^jIZvX%Q diff --git a/function/model_manager.py b/function/model_manager.py index 22873e6..b062962 100644 --- a/function/model_manager.py +++ b/function/model_manager.py @@ -833,7 +833,9 @@ class ModelManager: mlflow.log_param('algorithm', model_config['algorithm']) mlflow.log_param('task_type', model_config['task_type']) # mlflow.log_param('dataset', experiment_name.split('_')[0]) # 从实验名称提取数据集名称 - mlflow.log_param('dataset', model_config['dataset']) # 直接写数据集路径 + mlflow.log_param('dataset_train', train_path) # 直接写数据集路径 + mlflow.log_param('dataset_val', val_path) + # timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') # mlflow.log_param('start_time', timestamp)