93 lines
2.6 KiB
Python
93 lines
2.6 KiB
Python
import unittest
|
|
import pandas as pd
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from function.data_processor_date import DataProcessor
|
|
|
|
class TestDataProcessor(unittest.TestCase):
    """Tests for ``DataProcessor.process_dataset``.

    Each test writes a small CSV fixture to ``dataset/dataset_raw/test_data.csv``
    and asks the processor to write its results under
    ``dataset/dataset_processed``. Files created by the tests are removed
    afterwards so repeated runs do not leave artifacts behind.
    """

    def setUp(self):
        """Create the processor and write the CSV fixture to ``self.input_path``."""
        self.processor = DataProcessor()

        # Test data: 'feature1' contains one NaN so the cleaning step has a
        # missing value to handle.
        self.test_data = pd.DataFrame({
            'feature1': [1, 2, np.nan, 4, 5],
            'feature2': [10, 20, 30, 40, 50],
            'target': [0, 1, 0, 1, 0],
        })

        # Save the fixture where the processor will read it from.
        self.input_path = 'dataset/dataset_raw/test_data.csv'
        Path(self.input_path).parent.mkdir(parents=True, exist_ok=True)
        self.test_data.to_csv(self.input_path, index=False)

        # Directory the processor is asked to write its outputs into.
        self.output_dir = 'dataset/dataset_processed'

    def test_process_dataset(self):
        """A valid pipeline reports success and writes every split to disk."""
        # Processing methods to apply.
        # Cleaning step: a SimpleImputer configured with the 'mean' strategy.
        cleaning_methods = [
            {'method_name': 'SimpleImputer', 'params': {'strategy': 'mean'}},
        ]
        # Feature step: a StandardScaler with default parameters.
        feature_methods = [
            {'method_name': 'StandardScaler', 'params': {}},
        ]
        # Split sizes for the held-out test and validation sets
        # (presumably fractions of the rows — confirm against DataProcessor).
        split_params = {'test_size': 0.2, 'val_size': 0.2}

        # Process the dataset.
        result = self.processor.process_dataset(
            self.input_path,
            self.output_dir,
            cleaning_methods,
            feature_methods,
            split_params,
        )

        # Check the overall result; include it in the message so a failure
        # shows what the processor actually returned.
        self.assertEqual(
            result['status'], 'success',
            msg=f'process_dataset did not succeed: {result!r}',
        )
        self.assertIn('process_record', result)

        # Check that each split's output file was written.
        output_files = result['process_record']['output_files']
        for split in ('train', 'validation', 'test'):
            with self.subTest(split=split):
                self.assertIn(split, output_files)
                output_path = Path(output_files[split])
                # Register cleanup before asserting, so the file is removed
                # even if a later check fails.
                self.addCleanup(self._remove_if_exists, output_path)
                self.assertTrue(output_path.exists())

    def test_invalid_method(self):
        """An unknown cleaning method name is reported with status 'error'."""
        # An invalid method name.
        cleaning_methods = [
            {'method_name': 'InvalidMethod', 'params': {}},
        ]

        result = self.processor.process_dataset(
            self.input_path,
            self.output_dir,
            cleaning_methods,
            [],
            {'test_size': 0.2, 'val_size': 0.2},
        )

        self.assertEqual(result['status'], 'error')

    def tearDown(self):
        """Delete the CSV fixture written by ``setUp``."""
        # Clean up the test file; only a missing file is tolerated, so other
        # failures (e.g. permissions) are no longer silently swallowed.
        self._remove_if_exists(self.input_path)

    @staticmethod
    def _remove_if_exists(path):
        """Delete the file at *path*; a file that is already gone is not an error."""
        try:
            Path(path).unlink()
        except FileNotFoundError:
            pass
|
|
if __name__ == '__main__':
    # Run the test suite when this file is executed directly.
    unittest.main()