MLPlatform/test_model_trainer.py

85 lines
2.2 KiB
Python

import unittest
import numpy as np
from function.model_trainer import ModelTrainer
class TestModelTrainer(unittest.TestCase):
    """Tests for ``ModelTrainer.train_model``.

    Each test trains on a small synthetic binary-classification dataset and
    inspects the result dict returned by ``train_model``.
    """

    # Experiment name passed to every ``train_model`` call in this suite.
    EXPERIMENT_NAME = 'test_experiment'

    def setUp(self):
        """Create a fresh trainer and a reproducible synthetic dataset."""
        self.trainer = ModelTrainer()
        # Use a local Generator instead of ``np.random.seed``: reseeding the
        # legacy global RNG mutates process-wide state that leaks into other
        # tests and into any code that draws from ``np.random``.
        rng = np.random.default_rng(42)
        # Synthetic data: 5 standard-normal features, binary {0, 1} labels.
        self.X_train = rng.standard_normal((100, 5))
        self.y_train = rng.integers(0, 2, size=100)
        self.X_val = rng.standard_normal((30, 5))
        self.y_val = rng.integers(0, 2, size=30)
        # Payloads in the {'features', 'labels'} shape that train_model is
        # called with; built once here instead of duplicated in every test.
        self.train_data = {'features': self.X_train, 'labels': self.y_train}
        self.val_data = {'features': self.X_val, 'labels': self.y_val}

    def _train(self, model_config):
        """Run ``train_model`` on the shared train/val split.

        Args:
            model_config: dict with 'algorithm', 'task_type' and 'params'.

        Returns:
            The value returned by ``self.trainer.train_model``.
        """
        return self.trainer.train_model(
            self.train_data,
            self.val_data,
            model_config,
            self.EXPERIMENT_NAME,
        )

    def test_train_model(self):
        """A valid LogisticRegression config succeeds and reports metrics."""
        model_config = {
            'algorithm': 'LogisticRegression',
            'task_type': 'classification',
            'params': {
                'random_state': 42,
            },
        }

        result = self._train(model_config)

        # Top-level result fields.
        self.assertEqual(result['status'], 'success')
        self.assertIn('run_id', result)
        self.assertIn('metrics', result)

        # Classification metrics; subTest reports every missing key, not
        # just the first one.
        metrics = result['metrics']
        for name in ('accuracy', 'precision', 'recall', 'f1'):
            with self.subTest(metric=name):
                self.assertIn(name, metrics)

    def test_invalid_algorithm(self):
        """An unknown algorithm name returns an 'error' status (no raise)."""
        model_config = {
            'algorithm': 'InvalidAlgorithm',
            'task_type': 'classification',
            'params': {},
        }

        result = self._train(model_config)

        self.assertEqual(result['status'], 'error')
if __name__ == '__main__':
    # Executed directly as a script: discover and run every TestCase defined
    # in this module via unittest's command-line runner.
    unittest.main()