from .data_processor import DataProcessor
import pandas as pd
from typing import Dict, Optional, Tuple
from sklearn.model_selection import train_test_split


class DataSplitter(DataProcessor):
    """Dataset splitting: divide a DataFrame into train / validation / test sets."""

    def __init__(self, config: Optional[Dict] = None):
        """Create a splitter.

        Args:
            config: Optional configuration mapping, forwarded unchanged to
                ``DataProcessor.__init__``.
        """
        super().__init__(config)

    @staticmethod
    def _is_label_dtype(labels: pd.Series) -> bool:
        """Return True if *labels* has a discrete text/categorical dtype.

        Object columns have always been stratified. pandas' dedicated
        ``category`` and ``string`` dtypes hold the same kind of labels, but
        ``dtype == 'object'`` is False for them, so they are checked explicitly.
        """
        dtype = labels.dtype
        return (
            pd.api.types.is_object_dtype(dtype)
            or isinstance(dtype, pd.CategoricalDtype)
            or isinstance(dtype, pd.StringDtype)
        )

    @classmethod
    def _stratify_labels(
        cls, frame: pd.DataFrame, target: str, stratify: Optional[bool]
    ) -> Optional[pd.Series]:
        """Return the label column to stratify on, or None for a plain random split.

        ``frame[target]`` is always looked up first, so a missing target column
        raises ``KeyError`` regardless of the *stratify* setting.
        """
        labels = frame[target]
        if stratify is None:
            # Auto mode: stratify only on discrete text/categorical labels.
            stratify = cls._is_label_dtype(labels)
        return labels if stratify else None

    def train_val_test_split(
        self,
        df: pd.DataFrame,
        target: str,
        test_size: float = 0.2,
        val_size: float = 0.2,
        random_state: int = 42,
        stratify: Optional[bool] = None,
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Split *df* into training, validation and test sets.

        The test set is carved off first. The validation set is then carved
        off the remaining rows. As a result, ``val_size`` is a fraction of the
        train+val remainder, not of the full dataset: with the defaults
        (0.2 / 0.2) the split is roughly 64% / 16% / 20%.

        Args:
            df: Input data, including the target column.
            target: Name of the label column used for stratification.
            test_size: Test fraction (or absolute count) of ``df``, passed to
                ``sklearn.model_selection.train_test_split``.
            val_size: Validation fraction (or absolute count) of the rows left
                after removing the test set.
            random_state: Seed used for both splits, for reproducibility.
            stratify: ``None`` (default) stratifies automatically when the
                target column has object, ``category`` or ``string`` dtype.
                ``True`` always stratifies on ``df[target]``. ``False`` never
                stratifies.

        Returns:
            A ``(train, val, test)`` tuple of DataFrames.

        Raises:
            Any exception from the split (for example ``KeyError`` for a
            missing target column, or ``ValueError`` from scikit-learn when a
            class is too small to stratify). It is logged via
            ``self.logger.error`` and then re-raised unchanged.
        """
        try:
            # First split: separate the test set from the full data.
            train_val, test = train_test_split(
                df,
                test_size=test_size,
                random_state=random_state,
                stratify=self._stratify_labels(df, target, stratify),
            )

            # Second split: separate the validation set from the remaining rows.
            train, val = train_test_split(
                train_val,
                test_size=val_size,
                random_state=random_state,
                stratify=self._stratify_labels(train_val, target, stratify),
            )

            # NOTE(review): self.logger is presumably provided by DataProcessor.
            self.logger.info(f"""
            Data split complete:
            - Training set: {train.shape[0]} samples
            - Validation set: {val.shape[0]} samples
            - Test set: {test.shape[0]} samples
            """)

            return train, val, test

        except Exception as e:
            self.logger.error(f"Error splitting data: {str(e)}")
            raise