修改--修改模型训练接口的参数
This commit is contained in:
parent
d21060c670
commit
382271e424
@ -8,10 +8,14 @@ model_manager = ModelManager()
|
||||
|
||||
# 数据模型
|
||||
class TrainRequest(BaseModel):
|
||||
model: str
|
||||
dataset: Dict[str, str]
|
||||
|
||||
train_path: str
|
||||
val_path: str
|
||||
algorithm: str
|
||||
task_type : str
|
||||
parameters: Dict
|
||||
metrics: List[str]
|
||||
experiment_name: str
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
run_id: str
|
||||
@ -50,13 +54,15 @@ async def get_metrics():
|
||||
async def train_model(request: TrainRequest):
|
||||
"""模型训练"""
|
||||
result = model_manager.train_model(
|
||||
train_data=request.dataset['train'],
|
||||
val_data=request.dataset['val'],
|
||||
train_path=request.train_path,
|
||||
val_path=request.val_path,
|
||||
model_config={
|
||||
'model_name': request.model,
|
||||
'parameters': request.parameters,
|
||||
'algorithm': request.algorithm,
|
||||
'task_type': request.task_type,
|
||||
'params': request.parameters,
|
||||
'metrics': request.metrics
|
||||
}
|
||||
},
|
||||
experiment_name=request.experiment_name
|
||||
)
|
||||
if result['status'] == 'error':
|
||||
raise HTTPException(status_code=500, detail=result['message'])
|
||||
|
||||
@ -77,7 +77,7 @@ print("--------------------------------------------模型训练-----------------
|
||||
model_config = {
|
||||
'algorithm': 'XGBClassifier',
|
||||
'task_type': 'classification',
|
||||
'dataset' : '/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629',
|
||||
# 'dataset' : '/home/admin-root/haotian/MLPlatform/dataset/dataset_processed/breast_cancer_20250219_144629',
|
||||
'params': {
|
||||
'n_estimators': 100,
|
||||
'learning_rate': 0.1,
|
||||
|
||||
Binary file not shown.
@ -833,7 +833,9 @@ class ModelManager:
|
||||
mlflow.log_param('algorithm', model_config['algorithm'])
|
||||
mlflow.log_param('task_type', model_config['task_type'])
|
||||
# mlflow.log_param('dataset', experiment_name.split('_')[0]) # 从实验名称提取数据集名称
|
||||
mlflow.log_param('dataset', model_config['dataset']) # 直接写数据集路径
|
||||
mlflow.log_param('dataset_train', train_path) # 直接写数据集路径
|
||||
mlflow.log_param('dataset_val', val_path)
|
||||
|
||||
|
||||
# timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
# mlflow.log_param('start_time', timestamp)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user