修改--将读取文件集成到模型训练中

This commit is contained in:
haotian 2025-02-25 09:38:13 +08:00
parent aede371f38
commit d21060c670
5 changed files with 57 additions and 22 deletions

View File

@ -51,7 +51,7 @@ async def train_model(request: TrainRequest):
"""模型训练"""
result = model_manager.train_model(
train_data=request.dataset['train'],
val_data=request.dataset.get('val'),
val_data=request.dataset['val'],
model_config={
'model_name': request.model,
'parameters': request.parameters,

View File

@ -26,6 +26,28 @@ GET http://10.0.0.202:8992/data/feature/method/{method_name}
### 3.5 处理数据集
POST http://10.0.0.202:8992/data/process
传递参数
{
"input_path": "dataset/dataset_raw/breast_cancer.csv",
"output_dir": "dataset/dataset_processed",
"process_methods": [
{
"method_name": "IsolationForest",
"params": {
"contamination": 0.1,
"random_state": 42
}
}
],
"feature_methods": [
],
"split_params": {
"test_size": 0.1,
"val_size": 0.2
}
}
### 3.6 获取可用数据集列表
GET http://10.0.0.202:8992/data/datasets

View File

@ -71,15 +71,7 @@ print("--------------------------------------------获取所有已训练模型 e
print("--------------------------------------------模型训练---------------------------------------------------")
# 加载数据
train_data = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629/train_breast_cancer_20250219_144629.csv')
val_data = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629/val_breast_cancer_20250219_144629.csv')
# 准备特征和标签
X_train = train_data.drop('target', axis=1)
y_train = train_data['target']
X_val = val_data.drop('target', axis=1)
y_val = val_data['target']
# 模型配置
model_config = {
@ -98,14 +90,8 @@ model_config = {
# 模型文件 直接在 mlruns/文件夹下
for i in range(3, 4):
result = manager.train_model(
{
'features': X_train,
'labels': y_train
},
{
'features': X_val,
'labels': y_val
},
'/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250224_170615/train_breast_cancer_20250224_170615.csv',
'/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250224_170615/val_breast_cancer_20250224_170615.csv',
model_config,
f'breast_cancer_classification_{i}'
)

View File

@ -771,8 +771,8 @@ class ModelManager:
}
def train_model(
self,
train_data: Dict,
val_data: Dict,
train_path: str,
val_path: str,
model_config: Dict,
experiment_name: str
) -> Dict:
@ -789,6 +789,11 @@ class ModelManager:
训练结果字典
"""
try:
# 检查实验是否存在且被删除
experiment = mlflow.get_experiment_by_name(experiment_name)
if experiment and experiment.lifecycle_stage == 'deleted':
@ -800,6 +805,28 @@ class ModelManager:
# 设置MLflow实验
mlflow.set_experiment(experiment_name)
if os.path.exists(train_path):
# 加载数据
train_data = pd.read_csv(train_path)
else:
return {
'status': 'error',
'message': '找不到训练集路径'
}
if os.path.exists(val_path):
val_data = pd.read_csv(val_path)
else:
return{
'status': 'error',
'message': '找不到验证集路径'
}
# 准备特征和标签
X_train = train_data.drop('target', axis=1)
y_train = train_data['target']
X_val = val_data.drop('target', axis=1)
y_val = val_data['target']
with mlflow.start_run() as run:
# 记录基本信息
@ -833,15 +860,15 @@ class ModelManager:
# 训练模型
self.logger.info(f"Starting training {model_config['algorithm']}")
model.fit(train_data['features'], train_data['labels'])
model.fit(X_train, y_train)
# timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
# mlflow.log_param('end_time', timestamp)
# 在验证集上评估
val_predictions = model.predict(val_data['features'])
val_predictions = model.predict(X_val)
metrics = self._calculate_metrics(
val_data['labels'],
y_val,
val_predictions,
model_config['task_type']
)