"""Unit tests for ``ModelTrainer.train_model``."""

import unittest

import numpy as np

from function.model_trainer import ModelTrainer


class TestModelTrainer(unittest.TestCase):
    """Exercise ModelTrainer.train_model on a valid and an invalid config."""

    def setUp(self):
        """Create a trainer and a small, reproducible binary-classification dataset."""
        self.trainer = ModelTrainer()
        # Create test data from a local, seeded generator so the fixture is
        # reproducible without mutating NumPy's global RNG state, which other
        # tests running in the same process may depend on.
        rng = np.random.default_rng(42)
        self.X_train = rng.standard_normal((100, 5))
        self.y_train = rng.integers(0, 2, 100)
        self.X_val = rng.standard_normal((30, 5))
        self.y_val = rng.integers(0, 2, 30)
        # Every test passes the trainer the same payload layout, so build it once.
        self.train_data = {'features': self.X_train, 'labels': self.y_train}
        self.val_data = {'features': self.X_val, 'labels': self.y_val}

    def _train(self, model_config):
        """Run train_model on the shared fixture data under 'test_experiment'."""
        return self.trainer.train_model(
            self.train_data, self.val_data, model_config, 'test_experiment'
        )

    def test_train_model(self):
        """A valid LogisticRegression config succeeds and reports metrics."""
        model_config = {
            'algorithm': 'LogisticRegression',
            'task_type': 'classification',
            'params': {'random_state': 42},
        }

        result = self._train(model_config)

        # Verify the top-level result structure.
        self.assertEqual(result['status'], 'success')
        self.assertIn('run_id', result)
        self.assertIn('metrics', result)

        # Verify every expected classification metric is present; subTest
        # reports each missing metric instead of stopping at the first one.
        metrics = result['metrics']
        for metric_name in ('accuracy', 'precision', 'recall', 'f1'):
            with self.subTest(metric=metric_name):
                self.assertIn(metric_name, metrics)

    def test_invalid_algorithm(self):
        """An unknown algorithm name is reported through status 'error'."""
        model_config = {
            'algorithm': 'InvalidAlgorithm',
            'task_type': 'classification',
            'params': {},
        }

        result = self._train(model_config)

        self.assertEqual(result['status'], 'error')


if __name__ == '__main__':
    unittest.main()