99 lines
3.3 KiB
Python
99 lines
3.3 KiB
Python
import pytest
|
|
import pandas as pd
|
|
import numpy as np
|
|
import mlflow
|
|
from pathlib import Path
|
|
from function.model_manager import ModelManager
|
|
|
|
class TestModelManager:
    """Tests for ``ModelManager.predict``.

    Every file the tests create (the input CSV and the prediction output)
    lives under pytest's per-test ``tmp_path``. Runs therefore leave no
    artifacts in the working directory and do not collide when executed
    in parallel.
    """

    # Shape of the synthetic dataset. Assertions reference these instead
    # of repeating literals, so the fixture and the checks stay in sync.
    N_SAMPLES = 100
    N_FEATURES = 4

    @pytest.fixture
    def model_manager(self):
        """Return a fresh ``ModelManager`` instance."""
        return ModelManager()

    @pytest.fixture
    def sample_data(self, tmp_path):
        """Write a small deterministic classification CSV and return its path.

        The CSV has columns ``feature_0`` .. ``feature_{N_FEATURES-1}`` plus
        ``label``. ``label`` is 1 when ``feature_0 + feature_1 > 0`` and 0
        otherwise.

        Returns:
            str: path of the written CSV inside ``tmp_path``.
        """
        # Use a local Generator instead of np.random.seed(): seeding the
        # global RNG would silently change random state for unrelated tests.
        rng = np.random.default_rng(42)
        X = rng.standard_normal((self.N_SAMPLES, self.N_FEATURES))
        y = (X[:, 0] + X[:, 1] > 0).astype(int)

        df = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(self.N_FEATURES)])
        df['label'] = y

        data_path = tmp_path / "test_data.csv"
        df.to_csv(data_path, index=False)
        return str(data_path)

    @pytest.fixture
    def trained_model(self, sample_data):
        """Train a small random forest on ``sample_data``, log it to MLflow, and return the run ID.

        The fitted model is logged under the artifact path ``"model"``. The
        param ``algorithm=RandomForestClassifier`` is also recorded.

        Returns:
            str: the MLflow run ID of the logging run.
        """
        # Imported lazily so that only tests needing a trained model require sklearn.
        from sklearn.ensemble import RandomForestClassifier

        data = pd.read_csv(sample_data)
        X = data.drop('label', axis=1).values
        y = data['label'].values

        model = RandomForestClassifier(n_estimators=10, random_state=42)
        model.fit(X, y)

        # NOTE(review): the run is recorded in whatever MLflow tracking URI is
        # active (./mlruns by default) and is not cleaned up. It is left as-is
        # because ModelManager must be able to find the run; confirm which URI
        # ModelManager uses before redirecting this to tmp_path.
        with mlflow.start_run() as run:
            mlflow.sklearn.log_model(model, "model")
            mlflow.log_param("algorithm", "RandomForestClassifier")

        return run.info.run_id

    def test_predict(self, model_manager, sample_data, trained_model, tmp_path):
        """A valid run ID and dataset produce a predictions CSV and the requested metrics."""
        # tmp_path already exists, so no mkdir is needed for the output file.
        output_path = str(tmp_path / "test_predictions.csv")

        result = model_manager.predict(
            run_id=trained_model,
            data_path=sample_data,
            output_path=output_path,
            metrics=['accuracy', 'f1'],
        )

        # Check the summary returned by predict().
        assert result['status'] == 'success'
        assert 'prediction' in result
        prediction = result['prediction']
        assert Path(prediction['output_file']).exists()
        assert prediction['samples_count'] == self.N_SAMPLES
        assert 'accuracy' in prediction['metrics']
        assert 'f1' in prediction['metrics']

        # Check the file written at the caller-requested location.
        predictions = pd.read_csv(output_path)
        assert 'prediction' in predictions.columns
        assert len(predictions) == self.N_SAMPLES

    def test_predict_invalid_run_id(self, model_manager, sample_data, tmp_path):
        """An unknown run ID yields an error result whose message reports the run ID was not found."""
        result = model_manager.predict(
            run_id="invalid_run_id",
            data_path=sample_data,
            output_path=str(tmp_path / "invalid.csv"),
        )

        assert result['status'] == 'error'
        assert '未找到运行ID' in result['message']

    def test_predict_invalid_data_path(self, model_manager, trained_model, tmp_path):
        """A missing data file yields an error result whose message reports a data-loading failure."""
        # A path under tmp_path's never-created subdirectory is guaranteed not
        # to exist. A CWD-relative path would depend on where pytest is run.
        missing_data_path = tmp_path / "invalid" / "path" / "data.csv"

        result = model_manager.predict(
            run_id=trained_model,
            data_path=str(missing_data_path),
            output_path=str(tmp_path / "invalid.csv"),
        )

        assert result['status'] == 'error'
        assert '数据加载失败' in result['message']