From c2b2e20a4a03ab5e11a2d899f4961eed6664edb6 Mon Sep 17 00:00:00 2001 From: haotian <2421912570@qq.com> Date: Fri, 21 Feb 2025 11:34:54 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9--=E5=B0=86=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E7=AE=A1=E7=90=86=E6=89=80=E6=9C=89=E7=9A=84=E6=96=B9=E6=B3=95?= =?UTF-8?q?=E9=83=BD=E9=9B=86=E6=88=90=E5=88=B0=E4=B8=80=E4=B8=AA=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E4=B8=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.yaml | 39 +- doc/接口文档code.md | 75 ++- .../__pycache__/model_manager.cpython-39.pyc | Bin 9831 -> 18940 bytes function/method_reader_metric.py | 79 ---- function/method_reader_model.py | 135 ------ function/model_manager.py | 437 +++++++++++++++++- function/model_trainer.py | 298 ------------ 7 files changed, 547 insertions(+), 516 deletions(-) delete mode 100644 function/method_reader_metric.py delete mode 100644 function/method_reader_model.py delete mode 100644 function/model_trainer.py diff --git a/config/config.yaml b/config/config.yaml index dca3239..9d11c9b 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -1,2 +1,37 @@ -mlfow: - uri: "http://localhost:5000" \ No newline at end of file +# 服务器配置 +host: "10.0.0.202" +port: 8992 +workers: 4 +debug: true + +# MLflow配置 +mlflow_uri: "http://10.0.0.202:5000" + +# 数据处理配置 +dataset: + raw_dir: "dataset/dataset_raw" + processed_dir: "dataset/dataset_processed" + +# 模型配置 +model: + save_dir: "models" + batch_size: 32 + num_workers: 4 + +# 系统监控配置 +monitor: + log_dir: ".log" + resource_check_interval: 60 # 秒 + cleanup_interval: 86400 # 24小时 + max_log_days: 30 # 日志保留天数 + +# 安全配置 +security: + secret_key: "your-secret-key" + token_expire_minutes: 1440 # 24小时 + +# 性能配置 +performance: + max_concurrent_trains: 4 # 最大并发训练数 + cache_size: 1024 # MB + timeout: 3600 # 秒 \ No newline at end of file diff --git a/doc/接口文档code.md b/doc/接口文档code.md index e874777..6124707 100644 --- a/doc/接口文档code.md +++ b/doc/接口文档code.md @@ 
-523,6 +523,9 @@ Error Response: } ``` +### 2.9 模型优化 -- 未实现 + + ## 3. 系统监控 ### 3.1 获取资源使用情况 ```http @@ -681,7 +684,7 @@ Error Response: } ``` -### 3.3 获取系统中训练状态 +### 3.3 获取系统中训练状态 ---- 未完成, 等开发系统后台时再实现. ```http GET /api/train/status/{task_id} @@ -764,6 +767,76 @@ Error Response: } ``` +## 4. 系统后台整体实现 + +### 4.1 系统架构 +``` +MLPlatform/ +├── api/ # API接口层 +│ ├── __init__.py +│ ├── data_api.py # 数据处理相关接口 +│ ├── model_api.py # 模型相关接口 +│ └── system_api.py # 系统监控相关接口 +├── function/ # 功能实现层 +│ ├── data_processor.py # 数据处理类 +│ ├── model_manager.py # 模型管理类 +│ ├── model_trainer.py # 模型训练类 +│ ├── system_monitor.py # 系统监控类 +│ └── utils/ # 工具函数 +├── config/ # 配置文件 +│ └── config.yaml # 系统配置 +├── dataset/ # 数据集 +│ ├── dataset_raw/ # 原始数据 +│ └── dataset_processed/ # 处理后数据 +├── .log/ # 日志文件 +├── doc/ # 文档 +└── main.py # 主程序入口 +``` + +### 4.2 技术栈 +- FastAPI: Web框架 +- MLflow: 模型管理和实验跟踪 +- PyTorch/Scikit-learn: 机器学习框架 +- Pydantic: 数据验证 +- Uvicorn: ASGI服务器 + +### 4.3 主要功能 +1. 异步任务处理 + - 支持多个模型同时训练 + - 后台任务状态监控 + - 任务队列管理 + +2. 实时监控 + - 系统资源监控 + - 训练进度监控 + - 日志实时查看 + +3. 错误处理 + - 全局异常处理 + - 错误日志记录 + - 优雅降级策略 + +4. 安全性 + - API认证授权 + - 请求限流 + - 参数验证 + +### 4.4 性能优化 +1. 数据处理 + - 数据流式处理 + - 缓存机制 + - 批量处理 + +2. 模型训练 + - GPU利用优化 + - 分布式训练支持 + - 模型检查点 + +3. 系统监控 + - 性能指标采集 + - 资源使用预警 + - 自动清理机制 + ## 附录A:方法详细说明 ### A1. 
数据预处理方法 diff --git a/function/__pycache__/model_manager.cpython-39.pyc b/function/__pycache__/model_manager.cpython-39.pyc index 23d59e7e2931cf7f8afabff626bc5d1e6566107a..30dc1d68bdf2f014d943de8e671adccd3697fc34 100644 GIT binary patch literal 18940 zcmdUX36LDsnPz5I)_rt!^&z#+5~z_1q?W)K+dVR4KnSp56)l@5p!M|Ni6q|Lj6C8B_50zkdGO%pJ!S<-b#=_b-6V zC-Aj?ii9alttbZlY8ADr8Ja4~0)~#VKqXiW8KG*}2v;LU#H))MQPk;`ST%0MMOmXy7o*4aSD*Mq^_&Yh=AKHW{1HN4RoD^-AN)>SklJs@$fq2#X$9 zSoD}?j5Bq=l8ar$8kMP>>L&J8rYrUPZ>f~bnw<-{`tDNEcJ;j_%XTAo&f2AVtx!Qb z{jP$25Wf*tuuZ#EHFHTfRVWr4yilCWTg5sz-GNzd7E2ax@d^^i7b+Dm7n#1+yGrtU zF<)pDz3focEY$La16IA#u+6;5d7Zl%QD_}*6u8N}vIxJuwJdtm^HIxcP5)D)~mA>(?+3XdCk&RsdBL1Fm232EGehC zvE2o`u(w`dCU=wfm^G`;sl+nvBJE-CAJNw*@U=dvDvGHXDpL%NsYZZ(#MD_}US&E9 z9uFEp(8Ggjkftjw?pkD~+=RN+Vt_$x_Hg*#JwUP0Wm=d~jZ488(FSgc)bU z=xKzFB9~+zHBP3#KXKg6yy zhsz`F(FKM7gl(>AY@A)iv}qmf=##>C&|Ky9zGXhbwz3@hS?$gBYIY5J8AB_bZ9}iy z*#z=yyxJXXCvs~|U3{p`wevxC9s3YU*O}|f8^o8&ug45O%x*yWMy7#UCqbokW>)qU zYK0k-HwLiBOW$5teBsOI&n%oj^~LkwIY*=%D%NY$rJ16J>dqfs>-gS+uT??fsI#Gx z;wX8=R@<5#SWt1TJ4!papr|NshrHkL0=0^uB?PE$vfrg%Sr;6 z79XRThUBV@G^Ucb!5$7_+O$Sa4+Bq0vKr0REweH$R`-&^x1#RS$9LeInmY=tTB>d5 z^}4;|V4-f~bnn==_pVC8p04xij_F3NNNj0GRm`+1iDF`Qj`pKeE7|${dUR$DBT=-N znouM7#Bdc0q>*cm^b*rV=N=*;j?Lq>B2KUW7sM5vYe7IST}BT+A;MZ)ebl+3(!eF8F*q%OSAQM(9ssuR^UGHb{*7sETAg2^tFmLwG$*q zpU(=?D~H-)7A7Kc&~`ss+L5`_s?yR`rMVdd7HUUX#Ev=oF)MZ=L(&OBu0&J3I=Rm@8B( zm$Lf=O=KlaWT#6NGh3_M+39+t##F=xuwhSc|L98)24<~TXL!6*%}x6o#UeQA zbfZ$4%MzWLEGr4j>#Vu{R$}a;4!87IPG(<<^ANVz4Ve!>_E_%V9*jLzxA)M<;!anu z&zdz?r->R-t1xW}>g4O_Hoa7vuDj7&A1InaXt_ZtMgo@2UkdUq)Fv)hm7j$ub5%)t zu4)R}lSvNE_(mk#@wG;gD4HhdC#9v;ty)42X(@G7ZLaC3q26=VM?R!f(6R*id>$8F z;wgiFu)s;;8Csi&Mki8PP+0IeovF_U)+!(x&C%K+)Q4GUK^3H-lq2mZ5fsRTu4A@V zj#G|>@yH{M1O~Zq<_&tx&%C(!<}*!w0^Hyp@w8}X^YueUs_*FJFD1u6J6;`U`SIJw z_l@swUO&FIVBw5G$XnU%+3~H!WxUIZdBm*rcOf{eB|>*02X>xd1`Sji`Sq9;Fd0Fy`H3M9^|C4e9E4OzpWd7Z6X< zAa0n)C`N)PG=I<=ftX0ogCsU^BfcIhh^ProSBKx%1G*YP3RbYT|1tJ4&%WnKbNAW6 
z8Ze0FUv!`3|1UOx5h6W|<{>be=&!+O^vf8{t1B>?7|{$QVc{q_)fGa1+Y*;%jY2vW%ewNH}@S1r!;Ahyr`r*iV;|z4Qp?KHB^T zEuy4`Bt+EW6Avx@?yJ2bsyQra$k#Rm9W}4KtD}$nYHGNQ&Ni<^9e&4NN6lVk1oX}XgW{0Wz>M8U$2xPPzFkHt4q0TbHhrxds1{5lmg#EzIyZB@ zY$-S1V@0_*3SIljjdd-jT#`pHQ68m)cqWfiLaX2@N=SX@X-Wnu$xt#x32AI@T%JUs zvF22H0GFMo4QS`!Ymtnh^W=YeuYcYg>)-9(jx`vsm6V{ildFexEND3=yLm&!+J+6WP#@cq!%(B4<^{4 z)n0h~wZ$i%K7ZyJXk>4_xv=!$*_Q(DSg}&DtkQI;SRm~-?@!NiN4Pn|A!nBt7b|4w za#F9mt(&o%It$A-SNAAyrCL^;rYvbnQ;?)tOobCoxv_f-m4+!u-;=O>Kb1!?7ZbZQ zw~W@^p*+<6?x0Y+r|Xa74vU_=qD9nTj}bm)?jsYD^ne74TiX8I>W0>8Q8s}7HnzsZbsFPlTTml#-(Y1^^@`To z>Xof^)y=K-)$!H_QJ$$>Ro&9s2rW-@0*>wkosg3(Z*?NhfRifcoR|}L5>D7jJA+Qf z8FGf55oZ)Fu69;CV`y=Wv&LDA&pKzlv(eeWhUPc6vLM~;5%rYHnZ1n-cXbH|HDkm% zMuc%K3Dh%JzjKZg7lM71*;}8123acJ?b9OwlED6ZZ*k*y*34UofcM|K4=qP`dqB`V z+?3@YY?J47cl2)PYV|5qKGU+r9iqlva~m%(!k})h*D*wGrl&4*6Dx?~p0a^EfEJ09 zBD=z%vD*p_9Fjt9>SpAIKQ(2RW)2>xbE|tB%9Ai9y9I2+>2$_-Po;IF7{%|w9yAP;conl^sd{xwax*njr1B5@Pw-(MTAQ~4~FC#m~x#bK)8RFRU4 zzYD9SFSn@1rsN1E4J6H>PuC8?45?*(t~r^_Z8wIkLlr#n+C)V>`MltkV(mt}Q)L~g zijCl(Lr1V?v($w^3CRjNIw_sk;X#nWLmh$!Cz{Hc6RIjK51*Vn4fdwQRq{ z)^w@I2w=TNpfYpd|7CG2twl`TNcwa^u-S6imlm>55J$d9EbnQo#3JHVpuiKtv|3Lf zmAF$PAalc1M%TKptf;zBTx5#UK(~gj{Mz2~zYJ}O!CYvrzg)ET%;+L9?2Dv(fsyI6 zu+O0bBF$T+CSZDKde9oGV75J~k4_8l{s=ZA1x?7{6kxkkYV$`f zeJ;^T*%2p2SnV-k*R}>qN-K>LV73_f#^2IfgRP7+*p4$;zWgo7%mL&QEXLvs+Aihr zk40Sqbw5TOq$=tXojN<&PB|%-90MlP9$+c#z>rrq&{H;yHfgAVDK?07fMrBFB+_A| zL+vyh0pdLD{ifL{TXh^QxEf@uX;g7F=3lL$E4Fq%AV*;9e8;LEI;G9QQPmoSRlk#E?O|tB+VibdP+tN1JF5hWq&chP za|{u%sJ2Frj_&t4pv#+vkV)D@BFWcK31_s+K;ysSdd?qqi9>=S4}cH$#X^tD+=)RWo8 zGcR0x`NZO>ljpzn>V-ESSvq&Ruf3nSTY$dHrqVlgsiOPsr}!im-;CetpQ3Ct*449L z(%g_`7wm@RhCOiX(vU33emuKn&mB{D?7#ii-CJJLj8xB@OUyEo0J8{W2V@Ba$Oyw$ zt8-@Y7L6AccA_=mD$R&z$xOiPvI`t2I}UR*c@uP>9s7ZW%^WP02%BZ10ybZxmM<|k z**BaU^|_oIg+J?1-kzH^-7x%n>`lhliFdjd{me?utQ6=VhLV17Z~4>Le0hlx7;KvS>3E) zx6qihiJh)`$W@Ol=fAOC3@%EtK5x7he9j_xHKc|B={2wK)I^bWWPsn4TOi3@I~3n-{1}0GDn^o8 z76|~lhGq~_wf91L2H+uS83~la&xN)+^p6;fZ+rv8>ab4dLVgHl)ad$sJqS;j8hcj{ 
zhSbFW42>kZ`o%kuNE$#@f?$;QmYMG>Omt69@6w$+fU{j`Ff$M67QR9z5tUxGVh~Fr zr9`L|H{k3xmq?+2il93IYaNVX4Z=|04H4b~enNgOz|t`JJtZIn;ZI6H80K)u2@544 z0wo~oL?GfHZN-Fg5fjP<`QNlw9KFPkLDe{<^M~wUJLJS!cuZ--xeF1Vs3k;MtPIto z)q9p2^PCEmh_)SCn4nlaRh4@W6X_g{2N{HWq*1*vLh+pw5#4mEWbR&Sx^97O= zzx(!59LHPVf90(=pI?0C_aJ^f0r>YLewRIb2 ze^&nFN$&fZ6-9CHGM0(rK8XZ1wJVVb8Jr=Blp#ND8L>NuG?GVO@uZOvD!{>Y#3XyS zQkpi4b4B=(zysjQ()d%jo4-m)Kl13tc2xj?2-!1@NAPnr!Qjl%OZcIrtN}%0(oH)yd1|4nn?{rZHLgI$_cQ0_5k=$W0cxDS<*_lx~{9 zh`5Mb5x%+%;LZfL{hWBd-?@lPrFG2PbN$mi1h`&7T|fe$QuoOj|`S$CnKGN=h| zoa*7h5pZC^nL`P6SW5(mFUt>cU)22~6#JFjcHIY25ol7VA=hBwKD3nBsn1d`A=7-> z1E=9FChxIz6|q%9`Ur?-x@VzjWvvVp6&`0C5440mR}P8*gdjZ4nrE{S%pIY6u-7!$ zE&>o>y9v+X;m?S1;#4a<(TTRk!eX3=um`~cVV%Hh4Kx!r7T95cFie&}a8Gi+6X1|M zrOf>e*di$!w9>L_7=3id=(i@A9!eEgjz4iScX5Xd-?JQE$Whj=w?xW&i54iPUX zkB&D2mV^(I7?UyRiTQkoL0K;agnr78qp1;Qrd=pitj2npn%o2tvbFTgA1^-fmGe)1 z{o+$!yKwgGWOgc-=2ubgVLEHraQ;!`BtH>~s!k>e{~i@?Mq;GsTz7{Re1uRA`=Ak% zlM=lMC9%t1h=Ju|lCy{v{ZVHWBSKGE4Dv6j|ERwr=!+irm44loI7Gr5Qf$pUu?L#I z64MYYqD#ds6bZ!CG?)RDD=>u(q;PdSRln5Vjl_GDu700YZ{HWQ=+@PX4OVai#o5L^ z5EQf!H%e2NK`rrm!=7!}G7~#cu%X8b%qPUmBPHnFh{&W6r~E9hA1DYNR;pJMBLQUHVLk~)t0woQxaAm~I34RfEqSn>RIm>><4wMnY!l3+s--MngqzKATIaCg} zBNXlk&Ju6lWk=gUb4cYvJVT-!hYRh9P^@s)%Sm`)Vy#py3!WkCzAEYv5#6e>1AJV$VkbSr%{4elX+1@{=V zGwq=Tg=L=8!FdLOjl}Gc@@RS0F%7SQI z5F?!pLE#*NIy#K+2 zU)kR53}erSF`~v3A`iT%HB#G#^^Mb7yt;^el@nvMc8y0+w?)*A)O1=GP_;Eq3N?2* zVVPoe6D{&IDdny0+=6mML3!4`n)1$Cpk7I5_!vUt$RD)}b%(Qf{@1Y8&Txq&&pI)m zZ`j+M5j?5kLptvMhP~YxDNnGGF_;4AKigXGjIcCIKdd?Hr^(T}j*U6%Nlu6!zhLhu z@5J}o_H~3#Lf))732&jZuKXb<02FU6$Y2U!ETu^0etf|vhW1A5WJ3dbVt5-FC1V-RXiML*V z^ZZla_9c;IvKQWb4zlFpH;ylT@5_Em|Kk5j68U^!GJD~TCl?=n5ir@($uk#TJG=PC zaldPlSv_n5fSZ^@M>3J|u)-Z(UINwrOFvjVeH?+ei>JS}_-a24C;1E}fXLi4&ma(O z@$@6vty_p|ZOJ0eW(zpimR#RtdRbD(Zn*H~Q;W~Nu=LG`&OiRhvOS{tdx#2#IeYxA zAH9mLU1~qF3}~cw{x~aI7X1q+KfK>w>bw6OT#)|>C8YT_2}rXhcQivX^i<+Oe~KC- zSScJZEBCk|2r5XiJyMv8-CvZkEvLV8;g4R-X^7w!Pyk_VM$&@?WC$JqGxRiN079YI 
z^_!7_d;HJwBkcnM0do{fUaGoDY6H+kvT@)_G7ICtNC?h;pj4qqcm4@+6Ofo1GaqQI zLQbw(LbN?90{<7NYfy{^)PnYFOY?sK+JWW}6D1eE@C?NEWVRWYxNh22Euk=Y#)_+z zm{bvj@a#w+ipw+Tg#el-1ul@8VkE5s1&wqpl4sGo3l!8WHvBeFbqVF-{{>0zBd#uR zAkq)`zoLXR0wXj#M^V5=tQ#1`|25V6N&g=U|5N8C1AhZPuXhRG4^X z4PQ6xZJirFSg;T$$@zcAZTxOZz9~9@xnHo2=xzxAJrtfTY%w>8w;sqzZ5=EC&DH9b z8>to!;VE#7{~Nk5j7y=wg#f={Yb2N^R4y{j;O*T)V~}RS{{toekrF|N|B0^tf)aXg z67V7|$_)zThZ6ya=U{~95c%COjTk1M7iYtA)z1he$`C7qQY}DKXl$52Jw`;Y$GhZW z;Te;PVi`ffd5o}RS;&iWc}y&0pcB-CaDqk!VZ$L=0K-=n3oe{{v-z zP6`{@ugaPTymwMz3x$TEt^sqy2O)Vf6-Zo))HZ~-gx=HNT^CCS(ugXLAR<^`Z*&dx zETV;wW}yDvfD&MAQ2sL10tz6Zb!l}S*@f?ifY4>s4S`Xi6w*Dw=BWIAJ>;v0VtmZx z7y5h6VbEO{XYlr_4+05c9du%a=Jwxk(R>OrTeSg%kmklz3Rr3f$VW!DA-@hG?C|5k zK7Yje21MEj~)GipV;|IT0}7#+hjqc}d2M6Y=1U2!5}B$8#?%zW!Bt2fwdV zdf~xm{eA>1?C>l?i7$=!!zg%8pjZlxCp2CT02E#bn3p_WN4CjSF6!z6vB)APEFE#s zdteeD_7;pP7m_EH1{G@*(-yN_DGJq%k$EK!s&zmE)!&HwD0of>1rka@K4}_o!7#>6+ zT+LpDU5xTh0E!>NlqeAF@b!mbLI@Z1UPRjA4k&JLL5oN`z!`l#1;}ab^>hz<0y%p< zA%yJk(ZlzEOv#!3wBJ+E>nVhp03wm53qtWh!eqpVtJn;PUR=ffE4ZXlMt3CX3Wi4* zhKF#9gXnc&z*qN>J@9CjhO@hu><49@Wqh&+1}?~6{0g!s7wk}D6ZX&x)Djqnq*6jU z&VKX!xi3nLVT#{_X3g+E`C3PFZP!tWI6#jaJHf9T`YU~Jl~*sN{x8rUe}s~LznMD> zH%$f6i->u_%X5%oyi~XNb7+Y;YC3CN+}eK^f1fHuBkwHy9SWKo`p=N0rK|@jbkN9m zsOt#NQV#?Si0~4AC0z-#f92!y9_oz_bqetBvs9AuS1W=US3E(3K83@_A<&J-=^*XG z{WwSv5ES%XbP{mJ0(?reAvR6GU~ijXTo9#)=`d3h2%6I{}CnrK_lGz{|GAg8bbHzlHK{76`Lv_Ic@ zf6aThuFKNN7C85sv;{o&E=lff>JCy3U~bg%n+>tXl3y}2?PvF? 
zK6d{>xn=Br_2qOb2gn88nDBcG6wQr5Cze+UA#c3EcHx_qvLJ}{cK(KO;VUgum@0ye zeo1~oM+ZEB*HS>I!7_L~JUCN!psfQI3F6gD6+uNhxPHJ10SE{Ho|cM)?t}>#5U@!Y zenf(>(USlG5FT_&ed5s;g(k&>DicKbjsyqr_8iVZs3%Y70^UNZ$5jO2O1{oyuV6iuKN^{2w`Nc~6QZH)Hx7qgQni4mybj~N$bcwbkR$sSnZ5CKo4 zEGfrjsa;UZ;9>0apKub>;^iHh)oJx{6@gH4^#ihvzXM4?=&A@xLa4E~mSj5(u7E%# z%ru352a%(ICF-BnA7~9A{79~qB0AE5#1!&SI)gNCz)dIMX&q`0J4u!iyCcUQaw6q$ zd89Ko0*pl-fmU%gGLP4iG=+_JKq?9p6EM}!dSNQSd7p%&97V9w?cgdI;TBkhJ<51) zfg!QZRn8DSQGiH8c+0t?Ada>&M>9|m#joHh2wlppY2}F;^#Y2&z-3i0Vx{+x(N)g8l+fZ7f{AYCUx_i9n9cVZ5 zMHV*UdrTzH)+G}07@A9qcOF+$%~7eG_#=oX)lo_wLK8%#5KO__l$@i)#}uUdQ?Pb2 zxs4P;t$O^Ic&$|gbOC5GqV$fV&`S8Xs1JjZ-=M^yI|I`tTbQtTaT_5`#Rj6`&5l4K zWMsz60DX`l$+AZ=TWQ`4Es5k%ZdjTU@WSD}9ZK-7gd+L%85%4~C=8}noi&E=w0dSN z6?{Sg3>n@e3XHgzIQ};SMaIy!LFJO~uTjLTcml=%osJ%pUxw6p@BfHy+ypc~y!Gex z^PA{rMK5h!9#S#F&q&zOUMWL5?C%t;b6DfBa$;Hn@9?IzG_3nnAO$-d5iy+(H_weMoI7|Q+r*f1GcVm6}?f)q-)N}#JAz}$oXa!3S)@4=oLk^3on9zjSzOv42Pl_5-k z--a4~J0*7@xd=J|MCbh{5>62{{F7AqDJs`F-izaVDZh^rAQcLqqN_V8p$GsG#c}a{ z{PBN;EdEggInkEC_I8@eE+TP5jaun2PD6r9l9Z&77`k>##}a7|2grS}Q6zntJ-MtvQ{7me;07S2_y-k;SFMZ7qUTsWU+_zy67jtMIKYD&oI$jMjC zH&Jp0B^xLa98d5*VtZ}?@r9hu4gUdUb-d$H;Xk4Lzolf6l7EN9NR+I6;RsyNWRP+? zd4gN=cj*c)JTPTS1}UNE1GJVEI4{g`+PWUBP;6|D?jrnr4Zapff`AS!olXr#QmKtd zGwDd`)96-Jeks8_u#I;T2YXAvM&3#6OMhGg>%sLex*#8N#Ag0B@jJ>uYFqkasqIuE zK9>?R9ys%k7(aKFgun-6yyyg&kN5{0-aprv#oJW{OQPQ)0NW6Sa+Addh<&3K0A7Z> zyiz(KUfGx8=_4pN^s)Z@55kS?l7w|5o|KF}mVuDNu(>-gUQozQMBNpHEgt^ga8$T6 delta 2410 zcmZ{lO>h)N6o98^XLn~en@vK95c0DNB#{A0vH?MqpMWtW3PGgAqERNpc4wJnc4qCF zjY*6K760|1?E%YJ1$xng2a77Jtg=eISXx~0!eKmHmR0CQi}K*X_hv9bs?64Y+wb-3 z*WK@@^S30w9LZaajadc$etqwYO2dKA^JmnhPe9}fB}}0zBQ(LfxW*F8glQ~!AxV=g z#nKlPnqoT3z;Bv$a~(bzo^4maGPkU;MwSQ523Ck>a-eMnZ3|li+D4uQxpi4(YgyX` zjpo^1ylJ}mQp;sUEV6a(fXddhcBW2gV2B?bzB#_;mirCMI_qGa;HZ_UClsS7ze*e? 
z6H?Q*ZJJgtr5R!BWpz<8^&oR8yQoa-mkCjVhL;r3<)o*LlScWa_OLXQ!(_9(lswYE z9xm+XF$!vyo*rcaB=FRpFxFG;P-oO!JDZn}lqVKV&-AZiY1?LJy9* z06p@r%&7b#fBGm*R4oqhIaS3FUBl;Uj%B-?nX{I#Ew{{Rqs6A{egGilgxt}_y$+CLbQnUKuJhhOm)X7@6?9N^CT-&ExFu92m~xaSyR95D)5 zNLMj1{(zAa!>BxnID~i*F#@2ifq=`}ZpHKiO9WsxV#Ufw0Ry&DJbo zfxv}N^OjTbgdLo#%J;kPDeeJ>VJhAP**q{yxOz7 z^BAg<0<$43#Bs!}6m<7)T*82g`v4P0t0bV%&bNOTM-H%3bDWe=9a7Wq%?Z^cbr5i-)64N*} zgSaDtO_-`xYyVS4j`hu@hS7A_mGAnl6S7-=xMh0^Z!QkWYg>+xW3tQWicJp6VWT${ z1?_%0V+9@A$+2ia3rIq=09`IFJHv>K79 zh_i7phH=#t0s~ET@NY)$7E#QsIHCH#!~ef^E-tXRG;hLsAFromRW(O6S>HNIhUN9G z`%~Bq5QYug+LBwN{Li6c9)SxMQg8_W{(*(mKvd-MGqE5~4UQL{gZT-XtO`VZS1SK36+1BK;AS4`rUPvz z5oyF$IWTlG71ip?^39=LOK+ie5b*>cOiqFSfM!M@x1;#VMLiKF;nlD3FzKJOYCKdu zKh&$%3}?3R#S6GD4MP{qX)5r{vOgQw{}TGe7BLF9<^2`ZzKVEFF@Ue*@C`(iqWw6+ z+Mvle5m*!AytH?86tMJZI?4ltbp)j>)(r-Q3o1X}@nDIj;)j}M1^4St;G1kJB(``5 z-Tb#>#D}2km+N+ZQ!vPckq{-c83O!A7&X>P=Wzd0FV1!$x&dLrhIc)Q;v~XBop#pq z9PtdwFCs1?E&yoG_DyTnvK{NJ!-a-=-2Ju_tvrQd9Za;r*T=_FCz{%v)AKrMm27M{ z6U`cZqA0J9wP&F);2*N(i_`M=v7s)i&0AH6YSX^wig(ffM~Dv)%X07dNpexj@!sau NIK;ck@5gr#^&gd+S+4*9 diff --git a/function/method_reader_metric.py b/function/method_reader_metric.py deleted file mode 100644 index bf43fea..0000000 --- a/function/method_reader_metric.py +++ /dev/null @@ -1,79 +0,0 @@ -import yaml -from typing import Dict, List -import os -import logging -from pathlib import Path - -class MethodReader: - """方法配置读取器""" - - def __init__(self): - """初始化方法读取器""" - self.logger = logging.getLogger(__name__) - self.method_config = self._load_metrics() - - - def _load_metrics(self) -> Dict: - """加载方法配置文件""" - try: - config_path = Path('model/metrics.yaml') - if not config_path.exists(): - raise FileNotFoundError(f"Method config file not found at {config_path}") - - with open(config_path, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - - self.logger.info("Successfully loaded method config") - return config - - except Exception as e: - self.logger.error(f"Error loading method config: {str(e)}") 
- raise - - - - def get_metrics(self) -> Dict: - """获取预处理方法列表""" - try: - metrics = [] - - # 分类方法 - classification_metrics = self.method_config.get('classification', {}) - if classification_metrics: - metrics.append({ - "name": "classification_metrics", - "description": "分类方法评价指标", - "metric": classification_metrics - }) - - # 回归方法 - regression_metrics = self.method_config.get('regression', {}) - if regression_metrics: - metrics.append({ - "name": "regression_metrics", - "description": "回归方法评价指标", - "metric": regression_metrics - }) - - # 聚类方法 - clustering_metrics = self.method_config.get('clustering', {}) - if clustering_metrics: - metrics.append({ - "name": "clustering_metrics", - "description": "聚类方法评价指标", - "metric": clustering_metrics - }) - - return { - "status": "success", - "metric": metrics - } - - except Exception as e: - self.logger.error(f"Error getting preprocessing methods: {str(e)}") - return { - "status": "error", - "error": str(e) - } - - \ No newline at end of file diff --git a/function/method_reader_model.py b/function/method_reader_model.py deleted file mode 100644 index 2233e87..0000000 --- a/function/method_reader_model.py +++ /dev/null @@ -1,135 +0,0 @@ -import yaml -from typing import Dict, List -import os -import logging -from pathlib import Path - -class MethodReader: - """方法配置读取器""" - - def __init__(self): - """初始化方法读取器""" - self.logger = logging.getLogger(__name__) - self.method_config = self._load_model_config() - self.parameter_config = self._load_parameter_config() - - def _load_model_config(self) -> Dict: - """加载方法配置文件""" - try: - config_path = Path('model/model.yaml') - if not config_path.exists(): - raise FileNotFoundError(f"Method config file not found at {config_path}") - - with open(config_path, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - - self.logger.info("Successfully loaded method config") - return config - - except Exception as e: - self.logger.error(f"Error loading method config: {str(e)}") - raise - - def 
_load_parameter_config(self) -> Dict: - """加载参数配置文件""" - try: - config_path = Path('model/parameter.yaml') - if not config_path.exists(): - raise FileNotFoundError(f"Parameter config file not found at {config_path}") - - with open(config_path, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - - self.logger.info("Successfully loaded parameter config") - return config - except Exception as e: - self.logger.error(f"Error loading parameter config: {str(e)}") - raise - - def get_models(self) -> Dict: - """获取预处理方法列表""" - try: - models = [] - - # 分类方法 - classification_algorithms = list(self.method_config.get('classification_algorithms', {}).keys()) - if classification_algorithms: - models.append({ - "name": "classification_algorithms", - "description": "分类方法", - "method": classification_algorithms - }) - - # 回归方法 - regression_algorithms = list(self.method_config.get('regression_algorithms', {}).keys()) - if regression_algorithms: - models.append({ - "name": "regression_algorithms", - "description": "回归方法", - "method": regression_algorithms - }) - - # 聚类方法 - clustering_algorithms = list(self.method_config.get('clustering_algorithms', {}).keys()) - if clustering_algorithms: - models.append({ - "name": "clustering_algorithms", - "description": "聚类方法", - "method": clustering_algorithms - }) - - return { - "status": "success", - "models": models - } - - except Exception as e: - self.logger.error(f"Error getting preprocessing methods: {str(e)}") - return { - "status": "error", - "error": str(e) - } - - def get_model_details(self, method_name: str) -> Dict: - """获取指定方法的详细信息""" - try: - # 在各个方法类别中查找方法原理和优缺点 - method_info = None - for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']: - if method_name in self.method_config.get(category, {}): - method_info = self.method_config[category][method_name] - break - - if method_info is None: - raise ValueError(f"Method {method_name} not found in method config") - - # 查找方法参数信息 - 
parameter_info = None - for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']: - if method_name in self.parameter_config.get(category, {}): - parameter_info = self.parameter_config[category][method_name] - break - - if parameter_info is None: - raise ValueError(f"Method {method_name} not found in parameter config") - - # 组合返回信息 - return { - "status": "success", - "method": { - "name": method_name, - "description": parameter_info.get('description', ''), - "principle": method_info.get('principle', ''), - "advantages": method_info.get('advantages', []), - "disadvantages": method_info.get('disadvantages', []), - "applicable_scenarios": method_info.get('applicable_scenarios', []), - "parameters": parameter_info.get('parameters', []) - } - } - - except Exception as e: - self.logger.error(f"Error getting method details: {str(e)}") - return { - "status": "error", - "error": str(e) - } \ No newline at end of file diff --git a/function/model_manager.py b/function/model_manager.py index b8c36c5..9137c98 100644 --- a/function/model_manager.py +++ b/function/model_manager.py @@ -18,6 +18,11 @@ from sklearn.metrics import ( import torch from torch.utils.data import DataLoader, TensorDataset + +''' + 模型管理整体集成 +''' + class ModelManager: """模型管理类""" @@ -27,12 +32,33 @@ class ModelManager: self.logger = logging.getLogger(__name__) self._setup_logging() self._metrics_map() + self.method_config = self._load_metrics() + + self.method_config = self._load_model_config() + self.parameter_config = self._load_parameter_config() # 初始化MLflow客户端 self.mlflow_uri = self.config.get('mlflow_uri', 'http://10.0.0.202:5000') mlflow.set_tracking_uri(self.mlflow_uri) self.client = MlflowClient() + def _load_metrics(self) -> Dict: + """加载方法配置文件""" + try: + config_path = Path('model/metrics.yaml') + if not config_path.exists(): + raise FileNotFoundError(f"Method config file not found at {config_path}") + + with open(config_path, 'r', encoding='utf-8') as f: + 
config = yaml.safe_load(f) + + self.logger.info("Successfully loaded method config") + return config + + except Exception as e: + self.logger.error(f"Error loading method config: {str(e)}") + raise + def _setup_logging(self): """设置日志""" log_dir = Path('.log') @@ -47,6 +73,54 @@ class ModelManager: self.logger.addHandler(file_handler) self.logger.setLevel(logging.INFO) + def _load_model_config(self) -> Dict: + """加载方法配置文件""" + try: + config_path = Path('model/model.yaml') + if not config_path.exists(): + raise FileNotFoundError(f"Method config file not found at {config_path}") + + with open(config_path, 'r', encoding='utf-8') as f: + config_model = yaml.safe_load(f) + + self.logger.info("Successfully loaded model config") + + + config_path = Path('model/metrics.yaml') + if not config_path.exists(): + raise FileNotFoundError(f"Metrics config file not found at {config_path}") + + with open(config_path, 'r', encoding='utf-8') as f: + config_metric = yaml.safe_load(f) + + self.logger.info("Successfully loaded metrics config") + + + config = {**config_model, **config_metric} + return config + + + except Exception as e: + self.logger.error(f"Error loading method or metric config: {str(e)}") + raise + + def _load_parameter_config(self) -> Dict: + """加载参数配置文件""" + try: + config_path = Path('model/parameter.yaml') + if not config_path.exists(): + raise FileNotFoundError(f"Parameter config file not found at {config_path}") + + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + + self.logger.info("Successfully loaded parameter config") + return config + except Exception as e: + self.logger.error(f"Error loading parameter config: {str(e)}") + raise + + def _metrics_map(self): self.metrics_map={ 'accuracy' : accuracy_score, @@ -63,7 +137,99 @@ class ModelManager: 'completeness': completeness_score, 'silhouette' : silhouette_score } + + + def _get_algorithm_info(self, algorithm_name: str) -> Dict: + """获取算法信息""" + for category in 
['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']: + if algorithm_name in self.method_config.get(category, {}): + return self.method_config[category][algorithm_name] + raise ValueError(f"Algorithm {algorithm_name} not found in model info") + + + def _get_model_class(self, algorithm_name: str): + """获取模型类""" + # 分类算法 + from sklearn.linear_model import LogisticRegression + from sklearn.svm import SVC, OneClassSVM + from sklearn.tree import DecisionTreeClassifier + from sklearn.ensemble import ( + RandomForestClassifier, GradientBoostingClassifier, + AdaBoostClassifier, IsolationForest + ) + from sklearn.naive_bayes import GaussianNB + from sklearn.neighbors import KNeighborsClassifier + from sklearn.neural_network import MLPClassifier + import xgboost as xgb + import lightgbm as lgb + from catboost import CatBoostClassifier + # 回归算法 + from sklearn.linear_model import ( + LinearRegression, Ridge, Lasso, + ElasticNet + ) + from sklearn.svm import SVR + from sklearn.tree import DecisionTreeRegressor + from sklearn.ensemble import ( + RandomForestRegressor, GradientBoostingRegressor, + AdaBoostRegressor + ) + from catboost import CatBoostRegressor + from sklearn.neural_network import MLPRegressor + + # 聚类算法 + from sklearn.cluster import ( + KMeans, AgglomerativeClustering, + DBSCAN, SpectralClustering + ) + from sklearn.mixture import GaussianMixture + + algorithm_map = { + # 分类算法 + 'LogisticRegression': LogisticRegression, + 'SVC': SVC, + 'SVDD': OneClassSVM, # SVDD使用OneClassSVM实现 + 'DecisionTreeClassifier': DecisionTreeClassifier, + 'RandomForestClassifier': RandomForestClassifier, + 'XGBClassifier': xgb.XGBClassifier, + 'AdaBoostClassifier': AdaBoostClassifier, + 'CatBoostClassifier': CatBoostClassifier, + 'LGBMClassifier': lgb.LGBMClassifier, + 'GaussianNB': GaussianNB, + 'KNeighborsClassifier': KNeighborsClassifier, + 'MLPClassifier': MLPClassifier, + 'GradientBoostingClassifier': GradientBoostingClassifier, + + # 回归算法 + 
'LinearRegression': LinearRegression, + 'Ridge': Ridge, + 'Lasso': Lasso, + 'ElasticNet': ElasticNet, + 'SVR': SVR, + 'DecisionTreeRegressor': DecisionTreeRegressor, + 'RandomForestRegressor': RandomForestRegressor, + 'XGBRegressor': xgb.XGBRegressor, + 'AdaBoostRegressor': AdaBoostRegressor, + 'CatBoostRegressor': CatBoostRegressor, + 'LGBMRegressor': lgb.LGBMRegressor, + 'MLPRegressor': MLPRegressor, + + # 聚类算法 + 'KMeans': KMeans, + 'KMeansPlusPlus': KMeans, # KMeans++使用KMeans实现,通过init参数控制 + 'AgglomerativeClustering': AgglomerativeClustering, + 'DBSCAN': DBSCAN, + 'GaussianMixture': GaussianMixture, + 'SpectralClustering': SpectralClustering + } + + if algorithm_name not in algorithm_map: + raise ValueError(f"Unknown algorithm: {algorithm_name}") + + return algorithm_map[algorithm_name] + + def get_finished_models( self, page: int = 1, @@ -440,7 +606,96 @@ class ModelManager: 'execution_time': f"{execution_time:.2f}s" } } + + def get_models(self) -> Dict: + """获取预处理方法列表""" + try: + models = [] + # 分类方法 + classification_algorithms = list(self.method_config.get('classification_algorithms', {}).keys()) + if classification_algorithms: + models.append({ + "name": "classification_algorithms", + "description": "分类方法", + "method": classification_algorithms + }) + + # 回归方法 + regression_algorithms = list(self.method_config.get('regression_algorithms', {}).keys()) + if regression_algorithms: + models.append({ + "name": "regression_algorithms", + "description": "回归方法", + "method": regression_algorithms + }) + + # 聚类方法 + clustering_algorithms = list(self.method_config.get('clustering_algorithms', {}).keys()) + if clustering_algorithms: + models.append({ + "name": "clustering_algorithms", + "description": "聚类方法", + "method": clustering_algorithms + }) + + return { + "status": "success", + "models": models + } + + except Exception as e: + self.logger.error(f"Error getting preprocessing methods: {str(e)}") + return { + "status": "error", + "error": str(e) + } + + def 
get_model_details(self, method_name: str) -> Dict: + """获取指定方法的详细信息""" + try: + # 在各个方法类别中查找方法原理和优缺点 + method_info = None + for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']: + if method_name in self.method_config.get(category, {}): + method_info = self.method_config[category][method_name] + break + + if method_info is None: + raise ValueError(f"Method {method_name} not found in method config") + + # 查找方法参数信息 + parameter_info = None + for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']: + if method_name in self.parameter_config.get(category, {}): + parameter_info = self.parameter_config[category][method_name] + break + + if parameter_info is None: + raise ValueError(f"Method {method_name} not found in parameter config") + + # 组合返回信息 + return { + "status": "success", + "method": { + "name": method_name, + "description": parameter_info.get('description', ''), + "principle": method_info.get('principle', ''), + "advantages": method_info.get('advantages', []), + "disadvantages": method_info.get('disadvantages', []), + "applicable_scenarios": method_info.get('applicable_scenarios', []), + "parameters": parameter_info.get('parameters', []) + } + } + + except Exception as e: + self.logger.error(f"Error getting method details: {str(e)}") + return { + "status": "error", + "error": str(e) + } + + # except Exception as e: # error_msg = f"预测过程发生错误: {str(e)}" # self.logger.error(error_msg) @@ -451,4 +706,184 @@ class ModelManager: # 'error_type': type(e).__name__, # 'error_message': str(e) # } - # } \ No newline at end of file + # } + + def get_metrics(self) -> Dict: + """获取预处理方法列表""" + try: + metrics = [] + + # 分类方法 + classification_metrics = self.method_config.get('classification', {}) + if classification_metrics: + metrics.append({ + "name": "classification_metrics", + "description": "分类方法评价指标", + "metric": classification_metrics + }) + + # 回归方法 + regression_metrics = 
self.method_config.get('regression', {}) + if regression_metrics: + metrics.append({ + "name": "regression_metrics", + "description": "回归方法评价指标", + "metric": regression_metrics + }) + + # 聚类方法 + clustering_metrics = self.method_config.get('clustering', {}) + if clustering_metrics: + metrics.append({ + "name": "clustering_metrics", + "description": "聚类方法评价指标", + "metric": clustering_metrics + }) + + return { + "status": "success", + "metric": metrics + } + + except Exception as e: + self.logger.error(f"Error getting preprocessing methods: {str(e)}") + return { + "status": "error", + "error": str(e) + } + def train_model( + self, + train_data: Dict, + val_data: Dict, + model_config: Dict, + experiment_name: str + ) -> Dict: + """ + 训练模型 + + Args: + train_data: 训练数据,包含特征和标签 + val_data: 验证数据,包含特征和标签 + model_config: 模型配置,包含算法名称和参数 + experiment_name: MLflow实验名称 + + Returns: + 训练结果字典 + """ + try: + # 检查实验是否存在且被删除 + experiment = mlflow.get_experiment_by_name(experiment_name) + if experiment and experiment.lifecycle_stage == 'deleted': + # 如果实验被删除,则创建一个新的实验名称 + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + new_experiment_name = f"{experiment_name}_{timestamp}" + self.logger.info(f"Original experiment was deleted, creating new experiment: {new_experiment_name}") + experiment_name = new_experiment_name + + # 设置MLflow实验 + mlflow.set_experiment(experiment_name) + + with mlflow.start_run() as run: + # 记录基本信息 + mlflow.log_param('algorithm', model_config['algorithm']) + mlflow.log_param('task_type', model_config['task_type']) + # mlflow.log_param('dataset', experiment_name.split('_')[0]) # 从实验名称提取数据集名称 + mlflow.log_param('dataset', model_config['dataset']) # 直接写数据集路径 + + # timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + # mlflow.log_param('start_time', timestamp) + + # 记录模型参数 + for param_name, param_value in model_config['params'].items(): + mlflow.log_param(param_name, param_value) + + # 记录算法信息 + algorithm_info = 
self._get_algorithm_info(model_config['algorithm']) + mlflow.log_param('principle', algorithm_info['principle']) + mlflow.log_param('advantages', str(algorithm_info['advantages'])) + mlflow.log_param('disadvantages', str(algorithm_info['disadvantages'])) + + # 特殊处理KMeans++ + if model_config['algorithm'] == 'KMeansPlusPlus': + model_config['params']['init'] = 'k-means++' + + # 获取模型类和信息 + model_class = self._get_model_class(model_config['algorithm']) + + # 创建模型实例 + model = model_class(**model_config['params']) + + # 训练模型 + self.logger.info(f"Starting training {model_config['algorithm']}") + model.fit(train_data['features'], train_data['labels']) + + # timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + # mlflow.log_param('end_time', timestamp) + + # 在验证集上评估 + val_predictions = model.predict(val_data['features']) + metrics = self._calculate_metrics( + val_data['labels'], + val_predictions, + model_config['task_type'] + ) + + # 记录指标 + for metric_name, metric_value in metrics.items(): + mlflow.log_metric(metric_name, metric_value) + + # 保存模型 + mlflow.sklearn.log_model(model, "model") + + self.logger.info(f"Training completed. 
Run ID: {run.info.run_id}") + + return { + 'status': 'success', + 'run_id': run.info.run_id, + 'metrics': metrics, + 'algorithm_info': algorithm_info + } + + except Exception as e: + error_msg = f"Error training model: {str(e)}" + self.logger.error(error_msg) + return { + 'status': 'error', + 'message': error_msg + } + + def _calculate_metrics( + self, + true_labels: np.ndarray, + predictions: np.ndarray, + task_type: str + ) -> Dict: + """计算评估指标""" + metrics = {} + + if task_type == 'classification': + metrics['accuracy'] = accuracy_score(true_labels, predictions) + metrics['precision'] = precision_score(true_labels, predictions, average='weighted') + metrics['recall'] = recall_score(true_labels, predictions, average='weighted') + metrics['f1'] = f1_score(true_labels, predictions, average='weighted') + if len(np.unique(true_labels)) == 2: # 二分类问题 + metrics['roc_auc'] = roc_auc_score(true_labels, predictions) + + elif task_type == 'regression': + metrics['mae'] = mean_absolute_error(true_labels, predictions) + metrics['mse'] = mean_squared_error(true_labels, predictions) + metrics['rmse'] = np.sqrt(metrics['mse']) + metrics['r2'] = r2_score(true_labels, predictions) + metrics['explained_variance'] = explained_variance_score(true_labels, predictions) + + elif task_type == 'clustering': + metrics['adjusted_rand'] = adjusted_rand_score(true_labels, predictions) + metrics['homogeneity'] = homogeneity_score(true_labels, predictions) + metrics['completeness'] = completeness_score(true_labels, predictions) + if len(np.unique(predictions)) > 1: # 确保有多个簇 + metrics['silhouette'] = silhouette_score( + true_labels.reshape(-1, 1), + predictions + ) + + return metrics \ No newline at end of file diff --git a/function/model_trainer.py b/function/model_trainer.py deleted file mode 100644 index 3556652..0000000 --- a/function/model_trainer.py +++ /dev/null @@ -1,298 +0,0 @@ -import numpy as pd -import numpy as np -from typing import Dict, List, Optional -import logging -from 
pathlib import Path -import datetime -import yaml -import mlflow -from sklearn.metrics import ( - accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, - mean_absolute_error, mean_squared_error, r2_score, explained_variance_score, - adjusted_rand_score, homogeneity_score, completeness_score, silhouette_score -) - -class ModelTrainer: - """模型训练类""" - - def __init__(self, config: Dict = None): - """初始化模型训练器""" - self.config = config or {} - self.logger = logging.getLogger(__name__) - self._setup_logging() - self._load_metrics() - self._load_parameters() - self._load_model_info() - - # with open("confg/config.yaml", 'r', encoding='utf-8') as f: - # config = yaml.safe_load(f) - - # 初始化MLflow - mlflow.set_tracking_uri(self.config.get('mlflow_uri', 'http://10.0.0.202:5000')) - - def _setup_logging(self): - """设置日志""" - log_dir = Path('.log') - log_dir.mkdir(exist_ok=True) - - file_handler = logging.FileHandler( - log_dir / f'model_training_{datetime.datetime.now():%Y%m%d_%H%M%S}.log' - ) - file_handler.setFormatter( - logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - ) - self.logger.addHandler(file_handler) - self.logger.setLevel(logging.INFO) - - def _load_metrics(self): - """加载评估指标配置""" - try: - with open('model/metrics.yaml', 'r', encoding='utf-8') as f: - self.metrics_config = yaml.safe_load(f) - except Exception as e: - self.logger.error(f"Error loading metrics config: {str(e)}") - raise - - def _load_parameters(self): - """加载模型参数配置""" - try: - with open('model/parameter.yaml', 'r', encoding='utf-8') as f: - self.parameter_config = yaml.safe_load(f) - except Exception as e: - self.logger.error(f"Error loading parameter config: {str(e)}") - raise - - def _load_model_info(self): - """加载模型信息配置""" - try: - with open('model/model.yaml', 'r', encoding='utf-8') as f: - self.model_info = yaml.safe_load(f) - except Exception as e: - self.logger.error(f"Error loading model info: {str(e)}") - raise - - def _get_model_class(self, 
algorithm_name: str): - """获取模型类""" - # 分类算法 - from sklearn.linear_model import LogisticRegression - from sklearn.svm import SVC, OneClassSVM - from sklearn.tree import DecisionTreeClassifier - from sklearn.ensemble import ( - RandomForestClassifier, GradientBoostingClassifier, - AdaBoostClassifier, IsolationForest - ) - from sklearn.naive_bayes import GaussianNB - from sklearn.neighbors import KNeighborsClassifier - from sklearn.neural_network import MLPClassifier - import xgboost as xgb - import lightgbm as lgb - from catboost import CatBoostClassifier - - # 回归算法 - from sklearn.linear_model import ( - LinearRegression, Ridge, Lasso, - ElasticNet - ) - from sklearn.svm import SVR - from sklearn.tree import DecisionTreeRegressor - from sklearn.ensemble import ( - RandomForestRegressor, GradientBoostingRegressor, - AdaBoostRegressor - ) - from catboost import CatBoostRegressor - from sklearn.neural_network import MLPRegressor - - # 聚类算法 - from sklearn.cluster import ( - KMeans, AgglomerativeClustering, - DBSCAN, SpectralClustering - ) - from sklearn.mixture import GaussianMixture - - algorithm_map = { - # 分类算法 - 'LogisticRegression': LogisticRegression, - 'SVC': SVC, - 'SVDD': OneClassSVM, # SVDD使用OneClassSVM实现 - 'DecisionTreeClassifier': DecisionTreeClassifier, - 'RandomForestClassifier': RandomForestClassifier, - 'XGBClassifier': xgb.XGBClassifier, - 'AdaBoostClassifier': AdaBoostClassifier, - 'CatBoostClassifier': CatBoostClassifier, - 'LGBMClassifier': lgb.LGBMClassifier, - 'GaussianNB': GaussianNB, - 'KNeighborsClassifier': KNeighborsClassifier, - 'MLPClassifier': MLPClassifier, - 'GradientBoostingClassifier': GradientBoostingClassifier, - - # 回归算法 - 'LinearRegression': LinearRegression, - 'Ridge': Ridge, - 'Lasso': Lasso, - 'ElasticNet': ElasticNet, - 'SVR': SVR, - 'DecisionTreeRegressor': DecisionTreeRegressor, - 'RandomForestRegressor': RandomForestRegressor, - 'XGBRegressor': xgb.XGBRegressor, - 'AdaBoostRegressor': AdaBoostRegressor, - 'CatBoostRegressor': 
CatBoostRegressor, - 'LGBMRegressor': lgb.LGBMRegressor, - 'MLPRegressor': MLPRegressor, - - # 聚类算法 - 'KMeans': KMeans, - 'KMeansPlusPlus': KMeans, # KMeans++使用KMeans实现,通过init参数控制 - 'AgglomerativeClustering': AgglomerativeClustering, - 'DBSCAN': DBSCAN, - 'GaussianMixture': GaussianMixture, - 'SpectralClustering': SpectralClustering - } - - if algorithm_name not in algorithm_map: - raise ValueError(f"Unknown algorithm: {algorithm_name}") - - return algorithm_map[algorithm_name] - - def _get_algorithm_info(self, algorithm_name: str) -> Dict: - """获取算法信息""" - for category in ['classification_algorithms', 'regression_algorithms', 'clustering_algorithms']: - if algorithm_name in self.model_info.get(category, {}): - return self.model_info[category][algorithm_name] - raise ValueError(f"Algorithm {algorithm_name} not found in model info") - - def train_model( - self, - train_data: Dict, - val_data: Dict, - model_config: Dict, - experiment_name: str - ) -> Dict: - """ - 训练模型 - - Args: - train_data: 训练数据,包含特征和标签 - val_data: 验证数据,包含特征和标签 - model_config: 模型配置,包含算法名称和参数 - experiment_name: MLflow实验名称 - - Returns: - 训练结果字典 - """ - try: - # 检查实验是否存在且被删除 - experiment = mlflow.get_experiment_by_name(experiment_name) - if experiment and experiment.lifecycle_stage == 'deleted': - # 如果实验被删除,则创建一个新的实验名称 - timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') - new_experiment_name = f"{experiment_name}_{timestamp}" - self.logger.info(f"Original experiment was deleted, creating new experiment: {new_experiment_name}") - experiment_name = new_experiment_name - - # 设置MLflow实验 - mlflow.set_experiment(experiment_name) - - with mlflow.start_run() as run: - # 记录基本信息 - mlflow.log_param('algorithm', model_config['algorithm']) - mlflow.log_param('task_type', model_config['task_type']) - # mlflow.log_param('dataset', experiment_name.split('_')[0]) # 从实验名称提取数据集名称 - mlflow.log_param('dataset', model_config['dataset']) # 直接写数据集路径 - - # timestamp = 
def train_model(
    self,
    train_data: Dict,
    val_data: Dict,
    model_config: Dict,
    experiment_name: str
) -> Dict:
    """Train a model and track the run in MLflow.

    Args:
        train_data: Dict with 'features' and 'labels' used for fitting.
        val_data: Dict with 'features' and 'labels' used for evaluation.
        model_config: Dict with keys 'algorithm', 'task_type', 'dataset'
            and 'params' (hyper-parameters for the estimator).
        experiment_name: Name of the MLflow experiment to log under.

    Returns:
        On success: {'status': 'success', 'run_id', 'metrics', 'algorithm_info'}.
        On failure: {'status': 'error', 'message'} (the exception is logged).
    """
    try:
        # A deleted MLflow experiment name cannot be reused directly, so
        # fall back to a timestamped variant when that happens.
        experiment = mlflow.get_experiment_by_name(experiment_name)
        if experiment and experiment.lifecycle_stage == 'deleted':
            timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
            new_experiment_name = f"{experiment_name}_{timestamp}"
            self.logger.info(f"Original experiment was deleted, creating new experiment: {new_experiment_name}")
            experiment_name = new_experiment_name

        mlflow.set_experiment(experiment_name)

        with mlflow.start_run() as run:
            # Record run-level metadata.
            mlflow.log_param('algorithm', model_config['algorithm'])
            mlflow.log_param('task_type', model_config['task_type'])
            mlflow.log_param('dataset', model_config['dataset'])

            # Record the caller-supplied hyper-parameters.
            for param_name, param_value in model_config['params'].items():
                mlflow.log_param(param_name, param_value)

            # Record descriptive algorithm metadata.
            algorithm_info = self._get_algorithm_info(model_config['algorithm'])
            mlflow.log_param('principle', algorithm_info['principle'])
            mlflow.log_param('advantages', str(algorithm_info['advantages']))
            mlflow.log_param('disadvantages', str(algorithm_info['disadvantages']))

            # Instantiate from a copy of the params so the caller's config
            # dict is never mutated (previously the KMeans++ special case
            # wrote 'init' back into model_config['params']).
            params = dict(model_config['params'])
            if model_config['algorithm'] == 'KMeansPlusPlus':
                params['init'] = 'k-means++'

            model_class = self._get_model_class(model_config['algorithm'])
            model = model_class(**params)

            self.logger.info(f"Starting training {model_config['algorithm']}")
            model.fit(train_data['features'], train_data['labels'])

            # Evaluate on the validation split and log every metric.
            val_predictions = model.predict(val_data['features'])
            metrics = self._calculate_metrics(
                val_data['labels'],
                val_predictions,
                model_config['task_type']
            )
            for metric_name, metric_value in metrics.items():
                mlflow.log_metric(metric_name, metric_value)

            # Persist the fitted model as an MLflow artifact.
            mlflow.sklearn.log_model(model, "model")

            self.logger.info(f"Training completed. Run ID: {run.info.run_id}")

            return {
                'status': 'success',
                'run_id': run.info.run_id,
                'metrics': metrics,
                'algorithm_info': algorithm_info
            }

    except Exception as e:
        error_msg = f"Error training model: {str(e)}"
        self.logger.error(error_msg)
        return {
            'status': 'error',
            'message': error_msg
        }