"""Tests for ``ModelManager.predict`` against a small MLflow-logged model."""

from pathlib import Path

import mlflow
import numpy as np
import pandas as pd
import pytest

from function.model_manager import ModelManager


class TestModelManager:
    """Exercise ``ModelManager.predict`` on its success and error paths."""

    @pytest.fixture
    def model_manager(self):
        """Provide a default-constructed ``ModelManager``."""
        return ModelManager()

    @pytest.fixture
    def sample_data(self):
        """Write a seeded synthetic dataset to CSV and return its path.

        The dataset has 100 rows, four standard-normal feature columns
        (``feature_0`` .. ``feature_3``) and a binary ``label`` column that
        is 1 when ``feature_0 + feature_1 > 0``.

        Returns:
            The CSV path as a string.
        """
        # Fixed seed so every run produces the same rows.
        np.random.seed(42)
        n_samples = 100
        features = np.random.randn(n_samples, 4)
        labels = (features[:, 0] + features[:, 1] > 0).astype(int)

        frame = pd.DataFrame(
            features,
            columns=[f'feature_{i}' for i in range(4)],
        )
        frame['label'] = labels

        # Persist the dataset where the tests expect to find it.
        target_dir = Path("dataset/dataset_processed/test_data")
        target_dir.mkdir(parents=True, exist_ok=True)
        csv_path = target_dir / "test_data.csv"
        frame.to_csv(csv_path, index=False)
        return str(csv_path)

    @pytest.fixture
    def trained_model(self, sample_data):
        """Train a small random forest, log it to MLflow, return the run id.

        Args:
            sample_data: Path to the CSV produced by the ``sample_data``
                fixture.

        Returns:
            The id of the MLflow run that holds the logged model.
        """
        from sklearn.ensemble import RandomForestClassifier

        # Load the dataset and split it into features and target.
        frame = pd.read_csv(sample_data)
        labels = frame['label'].values
        features = frame.drop(columns='label').values

        # Fit a deliberately tiny, deterministic classifier.
        forest = RandomForestClassifier(n_estimators=10, random_state=42)
        forest.fit(features, labels)

        # Record the model under a fresh MLflow run.
        with mlflow.start_run() as run:
            mlflow.sklearn.log_model(forest, "model")
            mlflow.log_param("algorithm", "RandomForestClassifier")
        return run.info.run_id

    def test_predict(self, model_manager, sample_data, trained_model):
        """A valid run id and dataset yield a success result and a CSV."""
        # Destination for the prediction file.
        output_dir = Path("predictions/test")
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = str(output_dir / "test_predictions.csv")

        # Run the prediction.
        result = model_manager.predict(
            run_id=trained_model,
            data_path=sample_data,
            output_path=output_path,
            metrics=['accuracy', 'f1'],
        )

        # Check the returned summary.
        assert result['status'] == 'success'
        assert 'prediction' in result
        summary = result['prediction']
        assert Path(summary['output_file']).exists()
        assert summary['samples_count'] == 100
        for metric_name in ('accuracy', 'f1'):
            assert metric_name in summary['metrics']

        # Check the layout of the written prediction file.
        written = pd.read_csv(output_path)
        assert 'prediction' in written.columns
        assert len(written) == 100

    def test_predict_invalid_run_id(self, model_manager, sample_data):
        """An unknown run id yields an error result mentioning the run id."""
        outcome = model_manager.predict(
            run_id="invalid_run_id",
            data_path=sample_data,
            output_path="predictions/test/invalid.csv",
        )

        assert outcome['status'] == 'error'
        # Expected message fragment (Chinese for "run ID not found").
        assert '未找到运行ID' in outcome['message']

    def test_predict_invalid_data_path(self, model_manager, trained_model):
        """A missing data file yields an error result about data loading."""
        outcome = model_manager.predict(
            run_id=trained_model,
            data_path="invalid/path/data.csv",
            output_path="predictions/test/invalid.csv",
        )

        assert outcome['status'] == 'error'
        # Expected message fragment (Chinese for "data loading failed").
        assert '数据加载失败' in outcome['message']