From d21060c67020bc49df3e06a1f0addd03b64f4975 Mon Sep 17 00:00:00 2001 From: haotian <2421912570@qq.com> Date: Tue, 25 Feb 2025 09:38:13 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9--=E5=B0=86=E8=AF=BB=E5=8F=96?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E9=9B=86=E6=88=90=E5=88=B0=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E4=B8=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/model_api.py | 2 +- doc/安装文档.md | 22 +++++++++++ example_model_manager.py | 18 +-------- .../__pycache__/model_manager.cpython-39.pyc | Bin 19306 -> 19581 bytes function/model_manager.py | 37 +++++++++++++++--- 5 files changed, 57 insertions(+), 22 deletions(-) diff --git a/api/model_api.py b/api/model_api.py index 7e3837c..f3f4f7a 100644 --- a/api/model_api.py +++ b/api/model_api.py @@ -51,7 +51,7 @@ async def train_model(request: TrainRequest): """模型训练""" result = model_manager.train_model( train_data=request.dataset['train'], - val_data=request.dataset.get('val'), + val_data=request.dataset['val'], model_config={ 'model_name': request.model, 'parameters': request.parameters, diff --git a/doc/安装文档.md b/doc/安装文档.md index 6d66e13..3cd1652 100644 --- a/doc/安装文档.md +++ b/doc/安装文档.md @@ -26,6 +26,28 @@ GET http://10.0.0.202:8992/data/feature/method/{method_name} ### 3.5 处理数据集 POST http://10.0.0.202:8992/data/process +传递参数 + { + "input_path": "dataset/dataset_raw/breast_cancer.csv", + "output_dir": "dataset/dataset_processed", + "process_methods": [ + { + "method_name": "IsolationForest", + "params": { + "contamination": 0.1, + "random_state": 42 + } + } + ], + "feature_methods": [ + + ], + "split_params": { + "test_size": 0.1, + "val_size": 0.2 + } + } + ### 3.6 获取可用数据集列表 GET http://10.0.0.202:8992/data/datasets diff --git a/example_model_manager.py b/example_model_manager.py index afdaba2..1b1a071 100644 --- a/example_model_manager.py +++ b/example_model_manager.py @@ -71,15 +71,7 @@ print("--------------------------------------------获取所有已训练模型 e print("--------------------------------------------模型训练---------------------------------------------------") -# 加载数据 -train_data = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629/train_breast_cancer_20250219_144629.csv') -val_data = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629/val_breast_cancer_20250219_144629.csv') -# 准备特征和标签 -X_train = train_data.drop('target', axis=1) -y_train = train_data['target'] -X_val = val_data.drop('target', axis=1) -y_val = val_data['target'] # 模型配置 model_config = { @@ -98,14 +90,8 @@ model_config = { # 模型文件 直接在 mlruns/文件夹下 for i in range(3, 4): result = manager.train_model( - { - 'features': X_train, - 'labels': y_train - }, - { - 'features': X_val, - 'labels': y_val - }, + '/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250224_170615/train_breast_cancer_20250224_170615.csv', + '/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250224_170615/val_breast_cancer_20250224_170615.csv', model_config, f'breast_cancer_classification_{i}' ) diff --git a/function/__pycache__/model_manager.cpython-39.pyc b/function/__pycache__/model_manager.cpython-39.pyc index c96ae0d2313e6a9e68cf9e2bf6aa807bbbccbb1b..de479fb1c966cc81ccb5df58a2c3af07d9ffa4ab 100644 GIT binary patch delta 5159 zcmZu#32)D`{oCs(stj z)6>)4)6+Af3-^;RzE1+O-|rFNQ&RtY&&zi_9XMdMdR1@!khPf*iJi8+O6KQpwXe4L zCATJbtI6BguWx%Q9|$-6P0 zP&dmdC81L4L@{eBeUs$VRDENbCS_e$spOs_X)TCI9$8TeV?QAf??NOmM`&jM^)4!F z0aA2w&GZ|{Uc;{N&8u9CB@u*wM}QR#pUKNGb>_!d`Y?)x7P7 z59GJ2Jyg}SG}wVX834n-b5l!_L9rvSr9+id+O>&({-qlB=m-M{{p|C~=51)7 zh_vkVohX-HX6I=wb&$;V^n2l}^g0 z)Ob&CSDI>6VNZo>iOOCNHL@dBmq(`EplG*LGC3p$Mnf{PdnzkRr@6`apx#ea7OVtz;GJs8zj^ld7>aRWzDo&)0rfiy10f4a?%C2H9tjff=J;uG`gd z587}q0z()F73$n2dD1H-V={*>X!$W%^h^a0M< zrw$qR0o*(EHXL(1faKxcM&qgOG{s$QxFP@3G@8_UW<=}7$Sbu|(-|f46l8^>Mj=wF zt}3)Cf4uQg3zW#MbLuLOp!TN_ib@1KIA=>^4w>IWxCMbApc(XT0MSM7WEbYtkyFey zcX#bW$gG!AQ{W0(6QMbPhI-Ask2e0RmsS6G%MH4 zjK|=jr1O~WbJL#1@u*aC!5mSp^EGU`r!+~fNs=el*Zn-%G)nGZtqa#WFh3!M4lYdE z7xPtp2^lXVpxeU^m_ut_GVEBwv(?;;haWTdFds$8+Og)v8}tl;*ujX)ZL$oNgli`L`@HVjlENgKu@qY zTISSXyg>zcL2Iy|?{_CsZ7i@XWS1{mIObyX#+*ugzq07frg?K`5jLAme2yo1O5kWzm33`o0lL=Bg1Zn z1k;CNTqL$_dDRvyMv{}S8!P};UCD!KhCx)R;l-YKOpm8i8bp|;yJStWcgcFJS8}HN z^-RAGZO)fie)*EnpHY3$HpAuzLjKd`E(_Vk$cm~FtO|x#QMDM2XHc=>NUC~oTA|lL zPrnb~%+ZI>Jb&-m`=2|TzmM*o7&hF2G=l2*mgMVn$x_7k24;omFawPf&|L^d z6W|9PoJh2&z!t>Y(_Vkvz`@IOM1W!G^FFRTNpocCXEC=C1`89L?st!N>CiF9-0HprbgV2_R+%nEcEF z!UNy+@YSBR2*S8!xQo?ZxogDT=b7+k$uUARx^KddRL+(1DR#x7I3Kj^7Ip}_DPsYd zBU}ckoSVzICS|g&zJTIZJc{?QCFjYy6n9w2d9&^!t}pB6T)!?%lpGVr$=$=bzz8&g ziZ5$bd|}}jIL_OCX#0!W0cZ!zHc%yDVWRYy&?ej%y1|kw$pvvtwj>+G@#AEFMOK(7 zD`+ISQm!$WEmeX=8cX$Z;FN|(gfTKTZ5d~lg$q1Wg;KuHGEV3rrNW$Asr-fj6JQHX zCb_a~S+@`u%--YVE|M$H3Vn_VrxG${jtUP!#E6hBpQy}wvR<4I=up;ojEv;a+N!KS ziwe#_RR|OtaXG7Vlf`6arLSJ}dR<>N>oWf+)*+8GOuV!*GC2Y2{mEh;B z2W`6ri=fol1e-8sB?2yMGbDi~?^y~dhkw_@cejvgw+VXfxMk=GBIK%xFuYW$g{V}I z5mcND^U3P6Vf06_XT#?DJ7Ba@ldC>hy`%8ER9Fhs4neHT_g0}&i%QP{kRNt?_AN@{3v1y*iFK?GTj{?9c3^!YFLt8MEXVxj!Z+oE zFs5@jWR|DNyS!PRTO(qkXc$l5NP2(}H z$M6*zrXpS%Pr}&Z`Up+KYg)~xuvwwr;hR|~w>do|U5TxA>n40X!+Qtv(<#Ht4R4s6uzNGY)d1(A6vMrRrejgLKSp6w3A7zMw;uy;3KJ|bd2 zWUxT%5E>8~5$X|muJWwKgf*=3luq%4qR(N|)|E~t=r6GS8-%k6ZvsfZxE7TM_!7a6PWZ`kzzfipoGC?yf65GT(~FXWJGKBvQ~}I_ucot8S0^e)1gj0# zlUP3EUE(}ChplSAjGSOMv@ff@0+rhOv~-HThokwfJ&zPb__6j~z;$$5{2Yboi>c0?&?Fe1$Q-6DTHwIB{FonSPa`zHb}0`-&S_YzD`9E z8bHLc@2cKjn!x!8ZVg_&^V01}#W3`^4%ab?3mJApRb^=shbe?Kg3kU^-Et)l0RNR* zf@+45p>zsYrU7`cf9ItGiwiTja1$i8Q|z&DElIOi!;S20bt_Y=D@P+G^L#PT6e3V)U!?6J6t=*W(x=P`s#!LC4bMG{ z@O6YE2nvAQXoU~n5nZ?f_F=f)7qL$zh9mZ5Vwiney&5#A>PBy!bDj_%2`WjrTiqN< zMN`W7k}K|6_7-W&m(^^tlTB=(R`YW=)nXZ{8uTc8w)VqXj8xgqZEHI?8RKMf$?jj& z9T+%)r#^u2Aj16sGwxJI)sn_KIY945;&Fs~5bj0z4nU;Bq>7%_(}}2ID%6acdPK@Kwf6K);b2Mx3KDlF7hb5tYOvu(?H;MI+5nCafK== zgFYlEW7;%#PGSyGhI{S;i1>uHQ9T{c(5=Ykh0D*-Xv#<|@zx6sXKELY<0X0;0!7uL z;3>`2RN9t5)p*(lrLw=NuIjre{b_`vQo#;2?P)xU#5)j-DlFUWXD-UqaSJ<(|dqbF90iwMUr$49;Igz+)pN z+SIs4;g-|pqQmKnxL)k2yIPL9MC7&Q*R;MxAhwCz?m|Q9J#4+)Paa{1_AXeun!|Xv70&M1{;Zaq3-=6!FXtMPPl|&L`I#CU>)*wk=_rFVf3tY$2A@-#4nU+C85!` zfNL<~9H*KZrM)nb34I6tSh0ukC_~g9>0C=LWk2ea$?w>oI@`k-dxe&F+{E9jocw}82wLwk03nm8lR3&q+<~uHVskBfJSN!UpoYo4y$}X{P#rXpFx*L%bVB1!V?+Tb+A;?Bq?9oY z8VZAQ+~WXtd{LU=)JF1W*L!T_dPX)>&tg^JepNGKR3Asd+?CSIk&H^Curd(< zc=!GX-+lJZ_l`gHUjA-+!#wq$hwuBx<4+f;qHNi$vUM6bX3BS9L2ZC^Y!HSTJq@CW z(|^J0pzPxg%s|>W1XC*-@Se_FibQT;RLCyN(Rcv89syMp8d=F!WGMS+!(n(P$2ZsDMZV1xKq8&Z)s)GW-G%&;6D_tCkMK zHXafW9=u4jsK6Ha0(r{dCh}98UMGP?YpMv_zd08wzB?ejy?NOU;$-WlJVlzG@1gjT zk|NFlV2@HC4;$^#-!usMKF3VI9fXak5j0p@=yVc>(` z?eJ?1+9c^1nfZ*>UUFbon)9B7*&R3MJf<|~Q=M76>a3AY+K43PE||ND=5Cm~tvQT5 zHPW2_q|_^2S9z^1=gs+WPu83D;r@ejS!pg%kVkTUA-|EWrVxg#!ZFKwQ!#YxNYlo#cX9Y1B36R!c0N`=tkL zCkc_V!MTzwQA=y2>?~2s3f2VYO0za3o+P;t2rh#etYz^4^ng8Ak}WA%8wP7jb7f#r zdA2lLhKe5}2X8f{(TcgsY)~ynO5yN|tY_3cS~Y*TT4~L*C17^d>_i%p_DP)}%=n~3BI4gQfJ0LHGKXY_kJ+Yo5h+(rq?r!28oX3~grMMDg*dD_TY>he z%d!<}%?aC(uD5p9s(JXFL-&gBBK;o183gN%!;yGO7Lq>_f^XgEV%HS$TJ1Kp z0}(HN?}~J%kK@i50@g>n7B97?nq3N;1ZRJ8F92*|a{gtlieTBNP>_4kj?P zlC3DD6_a%}f{Y-N@F^6`Mt9WDQt@#4#PLz_D*COl)m}dgj1hk!i}CXk90ZgS%n9ZEn1$Wn#F<|83nh=fQvaoByf zfUT62+DcuOBxDbfQc_7O;cvhmvX|S#q|y~44&q{``v&|aXyv_lR9pUwz8INpK*Hrf z;GP@t+Y!2a=#<;gzSkW{$Ka&P5$X|I z5H2F?0iisAGt;UtWn2Z9d0hYtBLD?&pCCd{IxJ( Dict: @@ -789,6 +789,11 @@ class ModelManager: 训练结果字典 """ try: + + + + + # 检查实验是否存在且被删除 experiment = mlflow.get_experiment_by_name(experiment_name) if experiment and experiment.lifecycle_stage == 'deleted': @@ -800,6 +805,28 @@ class ModelManager: # 设置MLflow实验 mlflow.set_experiment(experiment_name) + + if os.path.exists(train_path): + # 加载数据 + train_data = pd.read_csv(train_path) + else: + return { + 'status': 'error', + 'message': '找不到训练集路径' + } + if os.path.exists(val_path): + val_data = pd.read_csv(val_path) + else: + return{ + 'status': 'error', + 'message': '找不到验证集路径' + } + + # 准备特征和标签 + X_train = train_data.drop('target', axis=1) + y_train = train_data['target'] + X_val = val_data.drop('target', axis=1) + y_val = val_data['target'] with mlflow.start_run() as run: # 记录基本信息 @@ -833,15 +860,15 @@ class ModelManager: # 训练模型 self.logger.info(f"Starting training {model_config['algorithm']}") - model.fit(train_data['features'], train_data['labels']) + model.fit(X_train, y_train) # timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') # mlflow.log_param('end_time', timestamp) # 在验证集上评估 - val_predictions = model.predict(val_data['features']) + val_predictions = model.predict(X_val) metrics = self._calculate_metrics( - val_data['labels'], + y_val, val_predictions, model_config['task_type'] )