49 lines
1.6 KiB
Python
49 lines
1.6 KiB
Python
from .data_processor import DataProcessor
|
|
import pandas as pd
|
|
from typing import Dict, Tuple
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
class DataSplitter(DataProcessor):
    """Split a dataset into training, validation and test subsets."""

    def __init__(self, config: Dict = None):
        """Forward the optional configuration mapping to ``DataProcessor``."""
        super().__init__(config)

    @staticmethod
    def _stratify_labels(df: pd.DataFrame, target: str):
        """Return ``df[target]`` when it holds class-like labels, else ``None``.

        A column counts as class-like when its dtype is object, string,
        boolean or categorical. Comparing against ``'object'`` alone misses
        label columns stored with the dedicated pandas string or categorical
        dtypes, which would silently disable stratification for them.
        Numeric targets (including integer-encoded classes) are never
        stratified, exactly as before.
        """
        labels = df[target]
        dtype = labels.dtype
        is_label_like = (
            pd.api.types.is_object_dtype(dtype)
            or pd.api.types.is_string_dtype(dtype)
            or pd.api.types.is_bool_dtype(dtype)
            or isinstance(dtype, pd.CategoricalDtype)
        )
        return labels if is_label_like else None

    def train_val_test_split(
        self,
        df: pd.DataFrame,
        target: str,
        test_size: float = 0.2,
        val_size: float = 0.2,
        random_state: int = 42
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Split ``df`` into training, validation and test sets.

        The split happens in two stages: the test set is carved out of the
        full frame first, then the validation set is carved out of what is
        left. Both stages are stratified on ``target`` when that column holds
        class-like labels (see ``_stratify_labels``).

        Args:
            df: Frame to split; must contain the ``target`` column.
            target: Name of the label column used to decide on stratification.
            test_size: Share of ``df`` assigned to the test set.
            val_size: Share of the samples *remaining after the test split*
                assigned to the validation set. It is not a share of the whole
                dataset: with the defaults the result is roughly
                64% / 16% / 20% (train / val / test).
            random_state: Seed passed to both splits for reproducibility.

        Returns:
            The ``(train, val, test)`` frames, in that order.

        Raises:
            Exception: Any error raised while splitting (for example a missing
                ``target`` column or an invalid size) is logged and re-raised
                unchanged.
        """
        # NOTE(review): ``self.logger`` is presumably set up by DataProcessor;
        # it is not defined in this class.
        try:
            # Stage 1: separate the test set from the full dataset.
            train_val, test = train_test_split(
                df,
                test_size=test_size,
                random_state=random_state,
                stratify=self._stratify_labels(df, target)
            )

            # Stage 2: separate the validation set from the remainder.
            train, val = train_test_split(
                train_val,
                test_size=val_size,
                random_state=random_state,
                stratify=self._stratify_labels(train_val, target)
            )
        except Exception as e:
            self.logger.error(f"Error splitting data: {str(e)}")
            raise
        else:
            self.logger.info(f"""
            Data split complete:
            - Training set: {train.shape[0]} samples
            - Validation set: {val.shape[0]} samples
            - Test set: {test.shape[0]} samples
            """)

            return train, val, test