修改--修改模型训练接口的参数

This commit is contained in:
haotian 2025-02-25 09:49:27 +08:00
parent d21060c670
commit 382271e424
4 changed files with 17 additions and 9 deletions

View File

@ -8,10 +8,14 @@ model_manager = ModelManager()
# 数据模型
class TrainRequest(BaseModel):
model: str
dataset: Dict[str, str]
train_path: str
val_path: str
algorithm: str
task_type : str
parameters: Dict
metrics: List[str]
experiment_name: str
class PredictRequest(BaseModel):
run_id: str
@ -50,13 +54,15 @@ async def get_metrics():
async def train_model(request: TrainRequest):
"""模型训练"""
result = model_manager.train_model(
train_data=request.dataset['train'],
val_data=request.dataset['val'],
train_path=request.train_path,
val_path=request.val_path,
model_config={
'model_name': request.model,
'parameters': request.parameters,
'algorithm': request.algorithm,
'task_type': request.task_type,
'params': request.parameters,
'metrics': request.metrics
}
},
experiment_name=request.experiment_name
)
if result['status'] == 'error':
raise HTTPException(status_code=500, detail=result['message'])

View File

@ -77,7 +77,7 @@ print("--------------------------------------------模型训练-----------------
model_config = {
'algorithm': 'XGBClassifier',
'task_type': 'classification',
'dataset' : '/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629',
# 'dataset' : '/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629',
'params': {
'n_estimators': 100,
'learning_rate': 0.1,

View File

@ -833,7 +833,9 @@ class ModelManager:
mlflow.log_param('algorithm', model_config['algorithm'])
mlflow.log_param('task_type', model_config['task_type'])
# mlflow.log_param('dataset', experiment_name.split('_')[0]) # 从实验名称提取数据集名称
mlflow.log_param('dataset', model_config['dataset']) # 直接写数据集路径
mlflow.log_param('dataset_train', train_path) # 直接写数据集路径
mlflow.log_param('dataset_val', val_path)
# timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
# mlflow.log_param('start_time', timestamp)