85 lines
2.2 KiB
Python
85 lines
2.2 KiB
Python
import unittest
|
|
import numpy as np
|
|
from function.model_trainer import ModelTrainer
|
|
|
|
class TestModelTrainer(unittest.TestCase):
    """Tests for ``ModelTrainer.train_model`` on a small synthetic binary-classification set."""

    def setUp(self):
        """Create a fresh trainer and reproducible random train/validation splits."""
        self.trainer = ModelTrainer()

        # Create test data. Seeding the global NumPy RNG right before drawing
        # keeps the splits identical on every run (and the call order below
        # must stay fixed for the same reason).
        np.random.seed(42)
        self.X_train = np.random.randn(100, 5)
        self.y_train = np.random.randint(0, 2, 100)
        self.X_val = np.random.randn(30, 5)
        self.y_val = np.random.randint(0, 2, 30)

    def _train(self, model_config, experiment_name='test_experiment'):
        """Run ``train_model`` on the fixture splits with *model_config*.

        Args:
            model_config: Dict passed through unchanged as the third
                positional argument of ``train_model``.
            experiment_name: Passed through as the fourth positional argument.

        Returns:
            Whatever ``self.trainer.train_model`` returns; the tests below
            index it as a dict.
        """
        train_data = {
            'features': self.X_train,
            'labels': self.y_train,
        }
        val_data = {
            'features': self.X_val,
            'labels': self.y_val,
        }
        # Arguments are passed positionally, exactly as the original tests did,
        # since the parameter names of train_model are not visible here.
        return self.trainer.train_model(train_data, val_data, model_config, experiment_name)

    def test_train_model(self):
        """A valid LogisticRegression config trains successfully and reports metrics."""
        # Model configuration
        model_config = {
            'algorithm': 'LogisticRegression',
            'task_type': 'classification',
            'params': {
                'random_state': 42,
            },
        }

        # Train the model
        result = self._train(model_config)

        # Check the result; include the whole result in the failure message so
        # an 'error' status shows what went wrong.
        self.assertEqual(result['status'], 'success', msg=result)
        self.assertIn('run_id', result)
        self.assertIn('metrics', result)

        # Check the metrics; subTest reports every missing key, not just the first.
        metrics = result['metrics']
        for metric_name in ('accuracy', 'precision', 'recall', 'f1'):
            with self.subTest(metric=metric_name):
                self.assertIn(metric_name, metrics)

    def test_invalid_algorithm(self):
        """An unknown algorithm name yields a result with status 'error'."""
        # Test an invalid algorithm name
        model_config = {
            'algorithm': 'InvalidAlgorithm',
            'task_type': 'classification',
            'params': {},
        }

        result = self._train(model_config)

        self.assertEqual(result['status'], 'error')
|
|
|
|
if __name__ == '__main__':
    # Run this module's tests when executed directly as a script.
    unittest.main()