MLPlatform/function_old/data_process/data_splitter.py

49 lines
1.6 KiB
Python

from .data_processor import DataProcessor
import pandas as pd
from typing import Dict, Tuple
from sklearn.model_selection import train_test_split
class DataSplitter(DataProcessor):
    """Split a dataset into training, validation and test subsets."""

    def __init__(self, config: Dict = None):
        # No splitter-specific state; the config is simply forwarded to
        # the DataProcessor base class.
        super().__init__(config)

    @staticmethod
    def _stratify_labels(labels: pd.Series):
        """Return ``labels`` when they should drive stratification, else ``None``.

        A column is treated as class labels when its dtype is object,
        a dedicated pandas string dtype, categorical, or boolean. Comparing
        the dtype against the literal ``'object'`` alone misses the last
        three, so those targets would silently be split without
        stratification. Numeric columns are left unstratified because a
        numeric target may be a regression value.
        """
        dtype = labels.dtype
        looks_categorical = (
            pd.api.types.is_object_dtype(dtype)
            or pd.api.types.is_string_dtype(dtype)
            or isinstance(dtype, pd.CategoricalDtype)
            or pd.api.types.is_bool_dtype(dtype)
        )
        return labels if looks_categorical else None

    def train_val_test_split(
        self,
        df: pd.DataFrame,
        target: str,
        test_size: float = 0.2,
        val_size: float = 0.2,
        random_state: int = 42
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Split ``df`` into training, validation and test sets.

        The split is done in two passes: the test set is carved off the
        full frame first, then the validation set is carved off what
        remains. ``val_size`` is therefore a fraction of the remaining
        (train + validation) rows, not of the full frame; with the defaults
        the result is roughly 64% / 16% / 20%.

        Args:
            df: Full dataset, including the target column.
            target: Name of the target column; used only to decide whether
                and how to stratify.
            test_size: Forwarded to ``train_test_split`` as ``test_size``
                for the first pass.
            val_size: Forwarded to ``train_test_split`` as ``test_size``
                for the second pass (applied to the non-test remainder).
            random_state: Seed used for both passes.

        Returns:
            A ``(train, val, test)`` tuple.

        Raises:
            Exception: Any error raised while splitting (for example a
                missing ``target`` column) is logged and re-raised unchanged.
        """
        try:
            # Pass 1: hold out the test set from the full frame.
            train_val, test = train_test_split(
                df,
                test_size=test_size,
                random_state=random_state,
                stratify=self._stratify_labels(df[target]),
            )
            # Pass 2: hold out the validation set from the remainder.
            train, val = train_test_split(
                train_val,
                test_size=val_size,
                random_state=random_state,
                stratify=self._stratify_labels(train_val[target]),
            )
            # self.logger is not defined in this class; presumably provided
            # by DataProcessor -- confirm against the base class.
            self.logger.info(
                "\n"
                "Data split complete:\n"
                f"- Training set: {train.shape[0]} samples\n"
                f"- Validation set: {val.shape[0]} samples\n"
                f"- Test set: {test.shape[0]} samples\n"
            )
            return train, val, test
        except Exception as e:
            # Log for diagnostics, then let the caller see the original error.
            self.logger.error(f"Error splitting data: {str(e)}")
            raise