修改--将读取文件集成到模型训练中
This commit is contained in:
parent
aede371f38
commit
d21060c670
@ -51,7 +51,7 @@ async def train_model(request: TrainRequest):
|
||||
"""模型训练"""
|
||||
result = model_manager.train_model(
|
||||
train_data=request.dataset['train'],
|
||||
val_data=request.dataset.get('val'),
|
||||
val_data=request.dataset['val'],
|
||||
model_config={
|
||||
'model_name': request.model,
|
||||
'parameters': request.parameters,
|
||||
|
||||
22
doc/安装文档.md
22
doc/安装文档.md
@ -26,6 +26,28 @@ GET http://10.0.0.202:8992/data/feature/method/{method_name}
|
||||
### 3.5 处理数据集
|
||||
POST http://10.0.0.202:8992/data/process
|
||||
|
||||
传递参数
|
||||
{
|
||||
"input_path": "dataset/dataset_raw/breast_cancer.csv",
|
||||
"output_dir": "dataset/dataset_processed",
|
||||
"process_methods": [
|
||||
{
|
||||
"method_name": "IsolationForest",
|
||||
"params": {
|
||||
"contamination": 0.1,
|
||||
"random_state": 42
|
||||
}
|
||||
}
|
||||
],
|
||||
"feature_methods": [
|
||||
|
||||
],
|
||||
"split_params": {
|
||||
"test_size": 0.1,
|
||||
"val_size": 0.2
|
||||
}
|
||||
}
|
||||
|
||||
### 3.6 获取可用数据集列表
|
||||
GET http://10.0.0.202:8992/data/datasets
|
||||
|
||||
|
||||
@ -71,15 +71,7 @@ print("--------------------------------------------获取所有已训练模型 e
|
||||
|
||||
|
||||
print("--------------------------------------------模型训练---------------------------------------------------")
|
||||
# 加载数据
|
||||
train_data = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629/train_breast_cancer_20250219_144629.csv')
|
||||
val_data = pd.read_csv('/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629/val_breast_cancer_20250219_144629.csv')
|
||||
|
||||
# 准备特征和标签
|
||||
X_train = train_data.drop('target', axis=1)
|
||||
y_train = train_data['target']
|
||||
X_val = val_data.drop('target', axis=1)
|
||||
y_val = val_data['target']
|
||||
|
||||
# 模型配置
|
||||
model_config = {
|
||||
@ -98,14 +90,8 @@ model_config = {
|
||||
# 模型文件 直接在 mlruns/文件夹下
|
||||
for i in range(3, 4):
|
||||
result = manager.train_model(
|
||||
{
|
||||
'features': X_train,
|
||||
'labels': y_train
|
||||
},
|
||||
{
|
||||
'features': X_val,
|
||||
'labels': y_val
|
||||
},
|
||||
'/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250224_170615/train_breast_cancer_20250224_170615.csv',
|
||||
'/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250224_170615/val_breast_cancer_20250224_170615.csv',
|
||||
model_config,
|
||||
f'breast_cancer_classification_{i}'
|
||||
)
|
||||
|
||||
Binary file not shown.
@ -771,8 +771,8 @@ class ModelManager:
|
||||
}
|
||||
def train_model(
|
||||
self,
|
||||
train_data: Dict,
|
||||
val_data: Dict,
|
||||
train_path: str,
|
||||
val_path: str,
|
||||
model_config: Dict,
|
||||
experiment_name: str
|
||||
) -> Dict:
|
||||
@ -789,6 +789,11 @@ class ModelManager:
|
||||
训练结果字典
|
||||
"""
|
||||
try:
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# 检查实验是否存在且被删除
|
||||
experiment = mlflow.get_experiment_by_name(experiment_name)
|
||||
if experiment and experiment.lifecycle_stage == 'deleted':
|
||||
@ -800,6 +805,28 @@ class ModelManager:
|
||||
|
||||
# 设置MLflow实验
|
||||
mlflow.set_experiment(experiment_name)
|
||||
|
||||
if os.path.exists(train_path):
|
||||
# 加载数据
|
||||
train_data = pd.read_csv(train_path)
|
||||
else:
|
||||
return {
|
||||
'status': 'error',
|
||||
'message': '找不到训练集路径'
|
||||
}
|
||||
if os.path.exists(val_path):
|
||||
val_data = pd.read_csv(val_path)
|
||||
else:
|
||||
return{
|
||||
'status': 'error',
|
||||
'message': '找不到验证集路径'
|
||||
}
|
||||
|
||||
# 准备特征和标签
|
||||
X_train = train_data.drop('target', axis=1)
|
||||
y_train = train_data['target']
|
||||
X_val = val_data.drop('target', axis=1)
|
||||
y_val = val_data['target']
|
||||
|
||||
with mlflow.start_run() as run:
|
||||
# 记录基本信息
|
||||
@ -833,15 +860,15 @@ class ModelManager:
|
||||
|
||||
# 训练模型
|
||||
self.logger.info(f"Starting training {model_config['algorithm']}")
|
||||
model.fit(train_data['features'], train_data['labels'])
|
||||
model.fit(X_train, y_train)
|
||||
|
||||
# timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
# mlflow.log_param('end_time', timestamp)
|
||||
|
||||
# 在验证集上评估
|
||||
val_predictions = model.predict(val_data['features'])
|
||||
val_predictions = model.predict(X_val)
|
||||
metrics = self._calculate_metrics(
|
||||
val_data['labels'],
|
||||
y_val,
|
||||
val_predictions,
|
||||
model_config['task_type']
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user