# Source: MLPlatform/function_old/test_model_manager.py (99 lines, 3.3 KiB, Python)

import pytest
import pandas as pd
import numpy as np
import mlflow
from pathlib import Path
from function.model_manager import ModelManager
class TestModelManager:
    """Integration tests for ``ModelManager.predict``.

    The fixtures build a small synthetic binary-classification dataset,
    train a random forest on it, and log the model to MLflow so that
    ``predict`` can be exercised through a real run ID.
    """

    # Row count of the synthetic dataset. Shared by the fixture and the
    # assertions so the two cannot drift apart.
    N_SAMPLES = 100

    @pytest.fixture
    def model_manager(self) -> ModelManager:
        """Return a ``ModelManager`` constructed with no arguments."""
        return ModelManager()

    @pytest.fixture
    def sample_data(self) -> str:
        """Write a synthetic labelled dataset to CSV and return its path.

        The frame has four columns ``feature_0`` .. ``feature_3`` drawn from
        a standard normal distribution and a ``label`` column that is 1 when
        ``feature_0 + feature_1 > 0`` and 0 otherwise.
        """
        # A local RandomState yields the same values as seeding the global
        # generator with 42, without altering the process-wide RNG state
        # that other tests may rely on.
        rng = np.random.RandomState(42)
        X = rng.randn(self.N_SAMPLES, 4)
        y = (X[:, 0] + X[:, 1] > 0).astype(int)

        # NOTE(review): the file is written relative to the current working
        # directory and is not removed afterwards; confirm whether
        # ModelManager requires this location before moving it to tmp_path.
        data_dir = Path("dataset/dataset_processed/test_data")
        data_dir.mkdir(parents=True, exist_ok=True)

        df = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(4)])
        df['label'] = y

        data_path = data_dir / "test_data.csv"
        df.to_csv(data_path, index=False)
        return str(data_path)

    @pytest.fixture
    def trained_model(self, sample_data: str) -> str:
        """Train a small random forest, log it to MLflow, return the run ID."""
        # Imported lazily, as in the original, so scikit-learn is only
        # required by tests that actually need a trained model.
        from sklearn.ensemble import RandomForestClassifier

        # Load the dataset produced by the ``sample_data`` fixture.
        data = pd.read_csv(sample_data)
        X = data.drop('label', axis=1).values
        y = data['label'].values

        # Fit a deliberately small model to keep the test fast.
        model = RandomForestClassifier(n_estimators=10, random_state=42)
        model.fit(X, y)

        # Record the model under the artifact path "model" in a new run.
        with mlflow.start_run() as run:
            mlflow.sklearn.log_model(model, "model")
            mlflow.log_param("algorithm", "RandomForestClassifier")
            return run.info.run_id

    def test_predict(
        self,
        model_manager: ModelManager,
        sample_data: str,
        trained_model: str,
    ) -> None:
        """``predict`` succeeds, reports metrics, and writes a CSV."""
        # NOTE(review): predictions are written under the current working
        # directory and left in place after the test.
        output_dir = Path("predictions/test")
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = str(output_dir / "test_predictions.csv")

        # Run the prediction against the logged model.
        result = model_manager.predict(
            run_id=trained_model,
            data_path=sample_data,
            output_path=output_path,
            metrics=['accuracy', 'f1'],
        )

        # Check the returned summary.
        assert result['status'] == 'success'
        assert 'prediction' in result
        assert Path(result['prediction']['output_file']).exists()
        assert result['prediction']['samples_count'] == self.N_SAMPLES
        assert 'accuracy' in result['prediction']['metrics']
        assert 'f1' in result['prediction']['metrics']

        # Check the shape of the file that was written.
        predictions = pd.read_csv(output_path)
        assert 'prediction' in predictions.columns
        assert len(predictions) == self.N_SAMPLES

    def test_predict_invalid_run_id(
        self,
        model_manager: ModelManager,
        sample_data: str,
    ) -> None:
        """An unknown run ID yields an error result instead of raising."""
        result = model_manager.predict(
            run_id="invalid_run_id",
            data_path=sample_data,
            output_path="predictions/test/invalid.csv",
        )

        assert result['status'] == 'error'
        assert '未找到运行ID' in result['message']

    def test_predict_invalid_data_path(
        self,
        model_manager: ModelManager,
        trained_model: str,
    ) -> None:
        """A non-existent data file yields an error result instead of raising."""
        result = model_manager.predict(
            run_id=trained_model,
            data_path="invalid/path/data.csv",
            output_path="predictions/test/invalid.csv",
        )

        assert result['status'] == 'error'
        assert '数据加载失败' in result['message']