生成--前端设计
This commit is contained in:
parent
1245847dd0
commit
e43ac40027
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,66 +0,0 @@
|
||||
from .data_processor import DataProcessor
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List
|
||||
from sklearn.impute import SimpleImputer
|
||||
from sklearn.ensemble import IsolationForest
|
||||
from scipy import stats
|
||||
|
||||
class DataCleaner(DataProcessor):
    """Data cleaning: missing-value imputation, duplicate removal and outlier filtering."""

    def __init__(self, config: Dict = None):
        """Set up the base processor and the table of supported imputers.

        Args:
            config: Optional settings dict, passed through to ``DataProcessor``.
        """
        super().__init__(config)
        # Public ``method`` name -> scikit-learn imputer that implements it.
        self.missing_value_methods = {
            'mean': SimpleImputer(strategy='mean'),
            'median': SimpleImputer(strategy='median'),
            'mode': SimpleImputer(strategy='most_frequent'),
            'constant': SimpleImputer(strategy='constant')
        }

    def handle_missing_values(self, df: pd.DataFrame, method: str = 'mean', columns: List[str] = None) -> pd.DataFrame:
        """Fill missing values in the selected columns.

        Note:
            ``df`` is modified in place and is also returned.

        Args:
            df: Frame to impute.
            method: One of the keys of ``self.missing_value_methods``.
            columns: Columns to impute; defaults to every numeric column.

        Returns:
            The same ``df`` object with the selected columns imputed.

        Raises:
            ValueError: If ``method`` is not a supported key.
        """
        try:
            if columns is None:
                columns = df.select_dtypes(include=[np.number]).columns

            if method not in self.missing_value_methods:
                raise ValueError(f"Unsupported method: {method}")

            # Fitting an imputer on zero columns fails inside scikit-learn,
            # so treat "nothing selected" as a no-op instead of an error.
            if len(columns) == 0:
                self.logger.info("No columns selected for imputation; data returned unchanged")
                return df

            imputer = self.missing_value_methods[method]
            df[columns] = imputer.fit_transform(df[columns])

            self.logger.info(f"Successfully handled missing values using {method} method")
            return df

        except Exception as e:
            self.logger.error(f"Error handling missing values: {str(e)}")
            raise

    def remove_duplicates(self, df: pd.DataFrame, subset: List[str] = None) -> pd.DataFrame:
        """Drop duplicate rows.

        Args:
            df: Frame to de-duplicate (not modified).
            subset: Columns that define a duplicate; ``None`` compares all columns.

        Returns:
            A new frame without the duplicate rows.
        """
        try:
            rows_before = df.shape[0]
            df = df.drop_duplicates(subset=subset)
            self.logger.info(f"Removed {rows_before - df.shape[0]} duplicate rows")
            return df
        except Exception as e:
            self.logger.error(f"Error removing duplicates: {str(e)}")
            raise

    def detect_outliers(self, df: pd.DataFrame, method: str = 'zscore', threshold: float = 3,
                        contamination: float = 0.1) -> pd.DataFrame:
        """Detect outlier rows using the numeric columns and return ``df`` without them.

        Args:
            df: Frame to filter (not modified).
            method: ``'zscore'`` or ``'isolation_forest'``.
            threshold: Absolute z-score above which a value counts as an outlier
                (``'zscore'`` only).
            contamination: Expected outlier fraction handed to ``IsolationForest``
                (``'isolation_forest'`` only).

        Returns:
            The rows of ``df`` that were not flagged as outliers.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        try:
            numeric = df.select_dtypes(include=[np.number])

            if method == 'zscore':
                values = numeric.to_numpy(dtype=float, na_value=np.nan)
                # nan_policy='omit' keeps a single missing value from turning the
                # whole column's z-scores into NaN, which would hide every outlier.
                z_scores = np.abs(np.asarray(stats.zscore(values, nan_policy='omit')))
                # NaN compares False, so missing cells never flag a row.
                outliers = (z_scores > threshold).any(axis=1)
            elif method == 'isolation_forest':
                iso_forest = IsolationForest(contamination=contamination, random_state=42)
                # fit_predict labels outliers as -1.
                outliers = iso_forest.fit_predict(numeric) == -1
            else:
                # Previously an unknown method fell through and failed later with
                # an UnboundLocalError on ``outliers``.
                raise ValueError(f"Unsupported outlier detection method: {method}")

            self.logger.info(f"Detected {int(np.sum(outliers))} outliers using {method} method")
            return df[~outliers]

        except Exception as e:
            self.logger.error(f"Error detecting outliers: {str(e)}")
            raise
|
||||
@ -1,63 +0,0 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Union, Optional
|
||||
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler
|
||||
from sklearn.impute import SimpleImputer
|
||||
from sklearn.model_selection import train_test_split
|
||||
import logging
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
|
||||
class DataProcessor:
    """Base class for data processing: holds configuration, a logger and file I/O helpers."""

    def __init__(self, config: Dict = None):
        """Store the configuration and prepare logging.

        Args:
            config: Optional settings dict. ``config['csv_params']`` (if present)
                is forwarded as keyword arguments to ``pandas.read_csv``.
        """
        self.config = config or {}
        self.logger = logging.getLogger(__name__)
        self._setup_logging()

    def _setup_logging(self):
        """Configure root logging at INFO level with a timestamped format."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )

    @staticmethod
    def _file_type(file_path: str) -> str:
        """Return the lower-cased extension of the file name, without the dot.

        Only the final path component is inspected, so a dot in a directory
        name (``run.v1/data``) is not mistaken for an extension. Returns ``''``
        when the file name contains no dot.
        """
        file_name = os.path.basename(file_path)
        _, dot, extension = file_name.rpartition('.')
        return extension.lower() if dot else ''

    def load_data(self, file_path: str) -> pd.DataFrame:
        """Load a dataset, choosing the reader from the file extension.

        Args:
            file_path: Path ending in ``.csv``, ``.parquet`` or ``.hdf5``.

        Returns:
            The loaded frame.

        Raises:
            ValueError: If the extension is not supported.
        """
        try:
            file_type = self._file_type(file_path)
            if file_type == 'csv':
                df = pd.read_csv(file_path, **self.config.get('csv_params', {}))
            elif file_type == 'parquet':
                df = pd.read_parquet(file_path)
            elif file_type == 'hdf5':
                df = pd.read_hdf(file_path)
            else:
                raise ValueError(f"Unsupported file type: {file_type}")

            self.logger.info(f"Successfully loaded data from {file_path}")
            return df

        except Exception as e:
            self.logger.error(f"Error loading data: {str(e)}")
            raise

    def save_data(self, df: pd.DataFrame, file_path: str):
        """Save a dataset, choosing the writer from the file extension.

        Args:
            df: Frame to write.
            file_path: Destination ending in ``.csv``, ``.parquet`` or ``.hdf5``.

        Raises:
            ValueError: If the extension is not supported. Nothing is written
                in that case.
        """
        try:
            file_type = self._file_type(file_path)
            if file_type == 'csv':
                df.to_csv(file_path, index=False)
            elif file_type == 'parquet':
                df.to_parquet(file_path)
            elif file_type == 'hdf5':
                df.to_hdf(file_path, key='data')
            else:
                # Without this branch an unknown extension wrote nothing but
                # still logged a success message.
                raise ValueError(f"Unsupported file type: {file_type}")

            self.logger.info(f"Successfully saved data to {file_path}")

        except Exception as e:
            self.logger.error(f"Error saving data: {str(e)}")
            raise
|
||||
@ -1,49 +0,0 @@
|
||||
from .data_processor import DataProcessor
|
||||
import pandas as pd
|
||||
from typing import Dict, Tuple
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
class DataSplitter(DataProcessor):
    """Splits a dataset into training, validation and test partitions."""

    def __init__(self, config: Dict = None):
        """Forward the optional configuration dict to ``DataProcessor``."""
        super().__init__(config)

    def train_val_test_split(
        self,
        df: pd.DataFrame,
        target: str,
        test_size: float = 0.2,
        val_size: float = 0.2,
        random_state: int = 42
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Split ``df`` into train, validation and test frames in two stages.

        The test partition is held out first; the remainder is then split into
        train and validation. ``val_size`` is passed to that second split, so
        it is a fraction of the rows left after the test set is removed, not of
        the full frame. Each stage stratifies on the target only when the
        target column has ``object`` dtype.

        Args:
            df: Frame to split.
            target: Name of the target column.
            test_size: ``test_size`` for the first split.
            val_size: ``test_size`` for the second split.
            random_state: Seed used for both splits.

        Returns:
            ``(train, val, test)`` frames.
        """
        try:
            # Stage 1: hold out the test partition.
            outer_labels = None
            if df[target].dtype == 'object':
                outer_labels = df[target]
            train_val, test = train_test_split(
                df,
                test_size=test_size,
                random_state=random_state,
                stratify=outer_labels,
            )

            # Stage 2: carve the validation partition out of what is left.
            inner_labels = None
            if train_val[target].dtype == 'object':
                inner_labels = train_val[target]
            train, val = train_test_split(
                train_val,
                test_size=val_size,
                random_state=random_state,
                stratify=inner_labels,
            )

            summary = (
                "\n"
                "Data split complete:\n"
                f"- Training set: {train.shape[0]} samples\n"
                f"- Validation set: {val.shape[0]} samples\n"
                f"- Test set: {test.shape[0]} samples\n"
            )
            self.logger.info(summary)

            return train, val, test

        except Exception as e:
            self.logger.error(f"Error splitting data: {str(e)}")
            raise
|
||||
@ -1,77 +0,0 @@
|
||||
from .data_processor import DataProcessor
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List
|
||||
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler
|
||||
from sklearn.feature_selection import SelectKBest, chi2, f_classif
|
||||
|
||||
class FeatureEngineer(DataProcessor):
    """Feature engineering: scaling, univariate feature selection and datetime expansion."""

    def __init__(self, config: Dict = None):
        """Set up the base processor and the table of supported scalers.

        Args:
            config: Optional settings dict, passed through to ``DataProcessor``.
        """
        super().__init__(config)
        # Public ``method`` name -> scikit-learn scaler that implements it.
        self.scalers = {
            'standard': StandardScaler(),
            'minmax': MinMaxScaler(),
            'robust': RobustScaler()
        }

    def scale_features(self, df: pd.DataFrame, method: str = 'standard', columns: List[str] = None) -> pd.DataFrame:
        """Scale the selected columns.

        Note:
            ``df`` is modified in place and is also returned.

        Args:
            df: Frame to scale.
            method: One of the keys of ``self.scalers``.
            columns: Columns to scale; defaults to every numeric column.

        Returns:
            The same ``df`` object with the selected columns scaled.

        Raises:
            ValueError: If ``method`` is not a supported key.
        """
        try:
            if columns is None:
                columns = df.select_dtypes(include=[np.number]).columns

            if method not in self.scalers:
                raise ValueError(f"Unsupported scaling method: {method}")

            # Fitting a scaler on zero columns fails inside scikit-learn,
            # so treat "nothing selected" as a no-op instead of an error.
            if len(columns) == 0:
                self.logger.info("No columns selected for scaling; data returned unchanged")
                return df

            scaler = self.scalers[method]
            df[columns] = scaler.fit_transform(df[columns])

            self.logger.info(f"Successfully scaled features using {method} method")
            return df

        except Exception as e:
            self.logger.error(f"Error scaling features: {str(e)}")
            raise

    def select_features(self, df: pd.DataFrame, target: str, method: str = 'chi2', k: int = 10) -> pd.DataFrame:
        """Keep the ``k`` best-scoring features plus the target column.

        Every non-target column is handed to the scorer as a feature.

        Args:
            df: Frame holding the features and the target (not modified).
            target: Name of the target column.
            method: ``'chi2'`` (requires non-negative features) or ``'f_classif'``.
            k: Number of features to keep; capped at the number of available
                features.

        Returns:
            A frame with the selected feature columns followed by ``target``.

        Raises:
            ValueError: If ``method`` is not one of the supported names.
        """
        try:
            X = df.drop(columns=[target])
            y = df[target]

            # The default k=10 can exceed the column count of a narrow frame;
            # asking for more features than exist is never meaningful.
            k = min(k, X.shape[1])

            if method == 'chi2':
                # chi2 requires the input features to be non-negative.
                selector = SelectKBest(chi2, k=k)
            elif method == 'f_classif':
                selector = SelectKBest(f_classif, k=k)
            else:
                raise ValueError(f"Unsupported feature selection method: {method}")

            # Only the support mask is needed, not the transformed array.
            selector.fit(X, y)
            selected_features = X.columns[selector.get_support()].tolist()

            self.logger.info(f"Selected {len(selected_features)} features")
            return df[selected_features + [target]]

        except Exception as e:
            self.logger.error(f"Error selecting features: {str(e)}")
            raise

    def create_datetime_features(self, df: pd.DataFrame, datetime_column: str) -> pd.DataFrame:
        """Expand a datetime column into year, month, day, weekday and weekend flag.

        Note:
            ``df`` is modified in place (the source column is converted with
            ``pd.to_datetime`` and five ``<column>_*`` columns are added) and
            is also returned.

        Args:
            df: Frame to extend.
            datetime_column: Name of the column holding the timestamps.

        Returns:
            The same ``df`` object with the new columns.
        """
        try:
            df[datetime_column] = pd.to_datetime(df[datetime_column])
            stamps = df[datetime_column].dt
            df[f'{datetime_column}_year'] = stamps.year
            df[f'{datetime_column}_month'] = stamps.month
            df[f'{datetime_column}_day'] = stamps.day
            df[f'{datetime_column}_weekday'] = stamps.weekday
            # weekday is 0 for Monday, so 5 and 6 are Saturday and Sunday.
            df[f'{datetime_column}_is_weekend'] = stamps.weekday.isin([5, 6])

            self.logger.info(f"Created datetime features from {datetime_column}")
            return df

        except Exception as e:
            self.logger.error(f"Error creating datetime features: {str(e)}")
            raise
|
||||
@ -1,34 +0,0 @@
|
||||
import unittest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from data_processor import DataProcessor
|
||||
from data_cleaner import DataCleaner
|
||||
from feature_engineer import FeatureEngineer
|
||||
from data_splitter import DataSplitter
|
||||
|
||||
class TestDataProcessor(unittest.TestCase):
    """Smoke tests for the cleaning, feature-engineering and splitting helpers."""

    def setUp(self):
        # Small fixture: a numeric column with one gap, a string column and a binary target.
        fixture_columns = {
            'feature1': [1, 2, np.nan, 4, 5],
            'feature2': ['A', 'B', 'A', 'B', 'C'],
            'target': [0, 1, 0, 1, 0],
        }
        self.test_data = pd.DataFrame(fixture_columns)

    def test_data_cleaner(self):
        """Default imputation leaves no null cell in the frame."""
        frame = self.test_data.copy()
        cleaned_data = DataCleaner().handle_missing_values(frame)
        any_null = cleaned_data.isnull().any().any()
        self.assertFalse(any_null)

    def test_feature_engineer(self):
        """Default scaling keeps the numeric feature column in the frame."""
        frame = self.test_data.copy()
        scaled_data = FeatureEngineer().scale_features(frame)
        self.assertIn('feature1', scaled_data.columns)

    def test_data_splitter(self):
        """The three partitions together contain every input row."""
        frame = self.test_data.copy()
        partitions = DataSplitter().train_val_test_split(frame, 'target')
        train, val, test = partitions
        total_rows = len(train) + len(val) + len(test)
        self.assertEqual(total_rows, len(self.test_data))


if __name__ == '__main__':
    unittest.main()
|
||||
@ -1,2 +0,0 @@
|
||||
from data_processor import DataProcessor
|
||||
|
||||
149
doc/接口文档code.md
149
doc/接口文档code.md
@ -914,6 +914,155 @@ MLPlatform/
|
||||
- 资源使用预警
|
||||
- 自动清理机制
|
||||
|
||||
## 5. 前端设计
|
||||
### 5.1 技术栈
|
||||
- Vue3: 前端框架
|
||||
- TypeScript: 编程语言
|
||||
- Element Plus: UI组件库
|
||||
- Axios: HTTP请求库
|
||||
- ECharts: 数据可视化库
|
||||
- Pinia: 状态管理
|
||||
- Vue Router: 路由管理
|
||||
|
||||
### 5.2 目录结构
|
||||
```
|
||||
frontend/
|
||||
├── src/
|
||||
│ ├── api/ # API接口封装
|
||||
│ │ ├── data.ts # 数据处理相关接口
|
||||
│ │ ├── model.ts # 模型管理相关接口
|
||||
│ │ └── system.ts # 系统监控相关接口
|
||||
│ ├── components/ # 公共组件
|
||||
│ │ ├── DataTable/ # 数据表格组件
|
||||
│ │ ├── ModelCard/ # 模型卡片组件
|
||||
│ │ └── Charts/ # 图表组件
|
||||
│ ├── views/ # 页面组件
|
||||
│ │ ├── data/ # 数据处理相关页面
|
||||
│ │ ├── model/ # 模型管理相关页面
|
||||
│ │ └── system/ # 系统监控相关页面
|
||||
│ ├── store/ # 状态管理
|
||||
│ ├── router/ # 路由配置
|
||||
│ └── utils/ # 工具函数
|
||||
├── public/ # 静态资源
|
||||
└── package.json # 项目配置
|
||||
```
|
||||
|
||||
### 5.3 页面设计
|
||||
1. 数据处理模块
|
||||
- 数据集列表页
|
||||
- 展示所有可用数据集
|
||||
- 支持数据集预览和基本统计信息
|
||||
- 数据集处理状态追踪
|
||||
- 数据预处理页
|
||||
- 预处理方法选择和配置
|
||||
- 参数可视化调整
|
||||
- 处理进度实时展示
|
||||
- 特征工程页
|
||||
- 特征工程方法选择
|
||||
- 特征重要性可视化
|
||||
- 数据分布展示
|
||||
|
||||
2. 模型管理模块
|
||||
- 模型列表页
|
||||
- 展示可用算法和模型
|
||||
- 模型详细信息查看
|
||||
- 模型对比功能
|
||||
- 模型训练页
|
||||
- 训练参数配置
|
||||
- 训练过程监控
|
||||
- 训练结果可视化
|
||||
- 模型评估页
|
||||
- 多指标评估结果
|
||||
- 预测结果分析
|
||||
- 模型解释性展示
|
||||
|
||||
3. 系统监控模块
|
||||
- 资源监控页
|
||||
- CPU/GPU使用率图表
|
||||
- 内存使用情况
|
||||
- 系统负载监控
|
||||
- 训练历史页
|
||||
- 实验记录列表
|
||||
- 训练详情查看
|
||||
- 实验对比分析
|
||||
- 日志查看页
|
||||
- 日志实时展示
|
||||
- 日志级别筛选
|
||||
- 日志搜索功能
|
||||
|
||||
### 5.4 交互设计
|
||||
1. 数据处理流程
|
||||
```mermaid
|
||||
graph LR
|
||||
A[上传数据] --> B[数据预览]
|
||||
B --> C[预处理配置]
|
||||
C --> D[特征工程]
|
||||
D --> E[数据划分]
|
||||
E --> F[处理完成]
|
||||
```
|
||||
|
||||
2. 模型训练流程
|
||||
```mermaid
|
||||
graph LR
|
||||
A[选择数据] --> B[选择算法]
|
||||
B --> C[参数配置]
|
||||
C --> D[开始训练]
|
||||
D --> E[监控进度]
|
||||
E --> F[查看结果]
|
||||
```
|
||||
|
||||
### 5.5 组件设计
|
||||
1. 通用组件
|
||||
- 数据表格组件
|
||||
- 图表展示组件
|
||||
- 参数配置表单
|
||||
- 进度展示组件
|
||||
- 文件上传组件
|
||||
|
||||
2. 业务组件
|
||||
- 数据预处理配置组件
|
||||
- 模型训练配置组件
|
||||
- 评估结果展示组件
|
||||
- 系统监控面板组件
|
||||
|
||||
### 5.6 状态管理
|
||||
1. 全局状态
|
||||
- 用户配置信息
|
||||
- 系统运行状态
|
||||
- 全局加载状态
|
||||
|
||||
2. 模块状态
|
||||
- 数据处理状态
|
||||
- 模型训练状态
|
||||
- 系统监控数据
|
||||
|
||||
### 5.7 性能优化
|
||||
1. 数据处理
|
||||
- 大数据分页加载
|
||||
- 数据缓存机制
|
||||
- 延迟加载策略
|
||||
|
||||
2. 交互优化
|
||||
- 防抖和节流
|
||||
- 骨架屏加载
|
||||
- 虚拟滚动列表
|
||||
|
||||
3. 可视化优化
|
||||
- 图表按需渲染
|
||||
- 数据分片处理
|
||||
- Web Worker 处理大数据
|
||||
|
||||
### 5.8 错误处理
|
||||
1. 全局错误处理
|
||||
- API请求错误
|
||||
- 组件渲染错误
|
||||
- 路由错误处理
|
||||
|
||||
2. 用户提示
|
||||
- 操作成功提示
|
||||
- 错误信息展示
|
||||
- 加载状态反馈
|
||||
|
||||
## 附录A:方法详细说明
|
||||
|
||||
### A1. 数据预处理方法
|
||||
|
||||
@ -1,93 +0,0 @@
|
||||
import unittest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from function.data_processor_date import DataProcessor
|
||||
|
||||
class TestDataProcessor(unittest.TestCase):
    """End-to-end tests for ``DataProcessor.process_dataset``."""

    def setUp(self):
        self.processor = DataProcessor()

        # Build the test data: one feature with a gap, one complete feature, a binary target.
        self.test_data = pd.DataFrame({
            'feature1': [1, 2, np.nan, 4, 5],
            'feature2': [10, 20, 30, 40, 50],
            'target': [0, 1, 0, 1, 0]
        })

        # Write the test data to disk so the processor can read it by path.
        self.input_path = 'dataset/dataset_raw/test_data.csv'
        Path(self.input_path).parent.mkdir(parents=True, exist_ok=True)
        self.test_data.to_csv(self.input_path, index=False)

        # Directory handed to the processor for its output.
        self.output_dir = 'dataset/dataset_processed'

    def test_process_dataset(self):
        """A valid pipeline reports success and produces all three split files."""
        # Processing steps to apply.
        cleaning_methods = [
            {
                'method_name': 'SimpleImputer',
                'params': {'strategy': 'mean'}
            }
        ]

        feature_methods = [
            {
                'method_name': 'StandardScaler',
                'params': {}
            }
        ]

        split_params = {
            'test_size': 0.2,
            'val_size': 0.2
        }

        # Run the pipeline.
        result = self.processor.process_dataset(
            self.input_path,
            self.output_dir,
            cleaning_methods,
            feature_methods,
            split_params
        )

        # Check the reported outcome.
        self.assertEqual(result['status'], 'success')
        self.assertIn('process_record', result)

        # Check that every output file listed in the record exists.
        record = result['process_record']
        self.assertTrue(Path(record['output_files']['train']).exists())
        self.assertTrue(Path(record['output_files']['validation']).exists())
        self.assertTrue(Path(record['output_files']['test']).exists())

    def test_invalid_method(self):
        """An unknown cleaning method name yields an error status."""
        cleaning_methods = [
            {
                'method_name': 'InvalidMethod',
                'params': {}
            }
        ]

        result = self.processor.process_dataset(
            self.input_path,
            self.output_dir,
            cleaning_methods,
            [],
            {'test_size': 0.2, 'val_size': 0.2}
        )

        self.assertEqual(result['status'], 'error')

    def tearDown(self):
        # Best-effort removal of the input file written by setUp. Only
        # filesystem errors are ignored; a bare ``except`` would also swallow
        # KeyboardInterrupt and SystemExit.
        try:
            Path(self.input_path).unlink()
        except OSError:
            pass


if __name__ == '__main__':
    unittest.main()
|
||||
@ -1,49 +0,0 @@
|
||||
import unittest
|
||||
from function.method_reader_date_process import MethodReader
|
||||
|
||||
class TestMethodReader(unittest.TestCase):
    """Tests for the method catalogue exposed by ``MethodReader``."""

    def setUp(self):
        self.reader = MethodReader()

    def test_get_preprocessing_methods(self):
        """The catalogue call succeeds and lists the three expected methods."""
        result = self.reader.get_preprocessing_methods()
        self.assertEqual(result['status'], 'success')
        self.assertIsInstance(result['methods'], list)

        # Every expected method name must appear in the returned list.
        methods = result['methods']
        for expected_name in ('data_scaler', 'missing_value_handler', 'outlier_detector'):
            self.assertTrue(any(m['name'] == expected_name for m in methods))

    def test_get_method_details(self):
        """Details for a known method carry all documented fields; unknown names fail."""
        # Look up the details of StandardScaler.
        result = self.reader.get_method_details('StandardScaler')
        self.assertEqual(result['status'], 'success')
        self.assertEqual(result['method']['name'], 'StandardScaler')

        # The detail record must expose each descriptive field.
        method = result['method']
        detail_fields = (
            'description',
            'principle',
            'advantages',
            'disadvantages',
            'applicable_scenarios',
            'parameters',
        )
        for field in detail_fields:
            self.assertIn(field, method)

        # Parameter entries, when present, must be fully described.
        parameters = method['parameters']
        self.assertIsInstance(parameters, list)
        if parameters:
            first_param = parameters[0]
            for field in ('name', 'type', 'default', 'description'):
                self.assertIn(field, first_param)

        # A method that does not exist is reported as an error.
        missing = self.reader.get_method_details('NonExistentMethod')
        self.assertEqual(missing['status'], 'error')


if __name__ == '__main__':
    unittest.main()
|
||||
@ -1,99 +0,0 @@
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import mlflow
|
||||
from pathlib import Path
|
||||
from function.model_manager import ModelManager
|
||||
|
||||
class TestModelManager:
    """Tests for ``ModelManager.predict`` against a model logged to MLflow."""

    @pytest.fixture
    def model_manager(self):
        """Fresh ``ModelManager`` for each test."""
        return ModelManager()

    @pytest.fixture
    def sample_data(self):
        """Write a seeded 100-row, 4-feature binary dataset to CSV and return its path."""
        np.random.seed(42)
        n_samples = 100
        X = np.random.randn(n_samples, 4)
        # Label is 1 when the first two features sum to a positive value.
        y = (X[:, 0] + X[:, 1] > 0).astype(int)

        # Persist the dataset where the manager can load it by path.
        data_dir = Path("dataset/dataset_processed/test_data")
        data_dir.mkdir(parents=True, exist_ok=True)

        feature_names = [f'feature_{i}' for i in range(4)]
        df = pd.DataFrame(X, columns=feature_names)
        df['label'] = y

        data_path = data_dir / "test_data.csv"
        df.to_csv(data_path, index=False)

        return str(data_path)

    @pytest.fixture
    def trained_model(self, sample_data):
        """Train a small random forest, log it to MLflow and return the run id."""
        from sklearn.ensemble import RandomForestClassifier

        # Load the dataset written by the sample_data fixture.
        data = pd.read_csv(sample_data)
        y = data['label'].values
        X = data.drop('label', axis=1).values

        # Fit the model.
        model = RandomForestClassifier(n_estimators=10, random_state=42)
        model.fit(X, y)

        # Record the model and its algorithm name in an MLflow run.
        with mlflow.start_run() as run:
            mlflow.sklearn.log_model(model, "model")
            mlflow.log_param("algorithm", "RandomForestClassifier")

        return run.info.run_id

    def test_predict(self, model_manager, sample_data, trained_model):
        """A valid run id and dataset yield a prediction file and the requested metrics."""
        # Output location for the predictions.
        output_dir = Path("predictions/test")
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = str(output_dir / "test_predictions.csv")

        # Run the prediction.
        result = model_manager.predict(
            run_id=trained_model,
            data_path=sample_data,
            output_path=output_path,
            metrics=['accuracy', 'f1']
        )

        # Check the reported outcome.
        assert result['status'] == 'success'
        assert 'prediction' in result
        prediction = result['prediction']
        assert Path(prediction['output_file']).exists()
        assert prediction['samples_count'] == 100
        for metric_name in ('accuracy', 'f1'):
            assert metric_name in prediction['metrics']

        # Check the layout of the written prediction file.
        predictions = pd.read_csv(output_path)
        assert 'prediction' in predictions.columns
        assert len(predictions) == 100

    def test_predict_invalid_run_id(self, model_manager, sample_data):
        """An unknown run id is reported as an error mentioning the run id."""
        result = model_manager.predict(
            run_id="invalid_run_id",
            data_path=sample_data,
            output_path="predictions/test/invalid.csv"
        )

        assert result['status'] == 'error'
        assert '未找到运行ID' in result['message']

    def test_predict_invalid_data_path(self, model_manager, trained_model):
        """A missing data file is reported as a data-loading error."""
        result = model_manager.predict(
            run_id=trained_model,
            data_path="invalid/path/data.csv",
            output_path="predictions/test/invalid.csv"
        )

        assert result['status'] == 'error'
        assert '数据加载失败' in result['message']
|
||||
@ -1,85 +0,0 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from function.model_trainer import ModelTrainer
|
||||
|
||||
class TestModelTrainer(unittest.TestCase):
    """Tests for ``ModelTrainer.train_model``."""

    def setUp(self):
        self.trainer = ModelTrainer()

        # Seeded random data: 100 training and 30 validation rows, 5 features, binary labels.
        np.random.seed(42)
        self.X_train = np.random.randn(100, 5)
        self.y_train = np.random.randint(0, 2, 100)
        self.X_val = np.random.randn(30, 5)
        self.y_val = np.random.randint(0, 2, 30)

    def _datasets(self):
        """Return the ``(train_data, val_data)`` dicts in the shape the trainer receives."""
        train_data = {'features': self.X_train, 'labels': self.y_train}
        val_data = {'features': self.X_val, 'labels': self.y_val}
        return train_data, val_data

    def test_train_model(self):
        """A supported algorithm trains successfully and reports all four metrics."""
        train_data, val_data = self._datasets()

        # Model configuration.
        model_config = {
            'algorithm': 'LogisticRegression',
            'task_type': 'classification',
            'params': {
                'random_state': 42
            }
        }

        # Train the model.
        result = self.trainer.train_model(
            train_data,
            val_data,
            model_config,
            'test_experiment'
        )

        # Check the reported outcome.
        self.assertEqual(result['status'], 'success')
        for key in ('run_id', 'metrics'):
            self.assertIn(key, result)

        # Check the metrics.
        metrics = result['metrics']
        for metric_name in ('accuracy', 'precision', 'recall', 'f1'):
            self.assertIn(metric_name, metrics)

    def test_invalid_algorithm(self):
        """An unknown algorithm name yields an error status."""
        train_data, val_data = self._datasets()

        model_config = {
            'algorithm': 'InvalidAlgorithm',
            'task_type': 'classification',
            'params': {}
        }

        result = self.trainer.train_model(
            train_data,
            val_data,
            model_config,
            'test_experiment'
        )

        self.assertEqual(result['status'], 'error')


if __name__ == '__main__':
    unittest.main()
|
||||
@ -1,86 +0,0 @@
|
||||
import pytest
|
||||
from function.system_monitor import SystemMonitor
|
||||
from typing import Dict
|
||||
|
||||
class TestSystemMonitor:
    """Tests for ``SystemMonitor.get_system_resources``."""

    @pytest.fixture
    def system_monitor(self):
        """Fresh ``SystemMonitor`` for each test."""
        return SystemMonitor()

    def test_get_system_resources(self, system_monitor):
        """The resource report has the expected structure and plausible values."""
        result = system_monitor.get_system_resources()

        # Top-level layout of the response.
        assert isinstance(result, dict)
        assert 'status' in result
        assert result['status'] == 'success'
        assert 'resources' in result
        assert 'timestamp' in result

        resources = result['resources']

        # GPU section: details are only checked when at least one GPU is reported.
        assert 'gpu' in resources
        if resources['gpu']:
            gpu = resources['gpu'][0]
            for key in ('id', 'name', 'memory', 'utilization', 'temperature'):
                assert key in gpu

        # CPU section.
        assert 'cpu' in resources
        cpu = resources['cpu']
        for key in ('count', 'utilization', 'memory', 'swap'):
            assert key in cpu

        # Memory section (nested under CPU).
        memory = cpu['memory']
        for key in ('total', 'used', 'free', 'percent'):
            assert key in memory
        assert memory['total'] > 0
        assert 0 <= memory['percent'] <= 100

        # Disk section: one entry per mount point.
        assert 'disk' in resources
        assert len(resources['disk']) > 0
        for disk_info in resources['disk'].values():
            for key in ('total', 'used', 'free', 'percent'):
                assert key in disk_info
            assert disk_info['total'] > 0
            assert 0 <= disk_info['percent'] <= 100

        # Process section.
        assert 'processes' in resources
        processes = resources['processes']
        for key in ('total', 'running', 'sleeping'):
            assert key in processes
        assert processes['total'] > 0
        assert processes['running'] >= 0
        assert processes['sleeping'] >= 0

    def test_error_handling(self, system_monitor, monkeypatch):
        """A failing GPU query must not prevent the other resources from being reported."""
        def mock_gpu_error(*args, **kwargs):
            raise Exception("GPU query failed")

        # Simulate a failing GPU query.
        monkeypatch.setattr(system_monitor, '_get_gpu_info', mock_gpu_error)

        result = system_monitor.get_system_resources()
        # The call still succeeds, with an empty GPU list.
        assert result['status'] == 'success'
        assert result['resources']['gpu'] == []

        # The remaining sections are still present.
        for section in ('cpu', 'disk', 'processes'):
            assert section in result['resources']
|
||||
@ -1,6 +0,0 @@
|
||||
from pathlib import Path

import pandas as pd
from sklearn.datasets import load_breast_cancer

# Export scikit-learn's breast-cancer dataset (features plus a 'target'
# column) to a CSV file.
OUTPUT_PATH = Path('./dataset/breast_cancer.csv')

cancer = load_breast_cancer()
df = pd.DataFrame(cancer.data, columns=cancer.feature_names)
df['target'] = cancer.target

# to_csv does not create missing directories, so make sure ./dataset exists.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(OUTPUT_PATH, index=False)
|
||||
Loading…
Reference in New Issue
Block a user