CostPrediction/tests/test_demo_service.py

50 lines
1.7 KiB
Python

from pathlib import Path
from src.demo_service import DemoModelService
def test_demo_service_loads_local_dataset():
    """Load the bundled demo CSV and check the dataset summary.

    The CSV path is resolved relative to this test file (project root is one
    directory above ``tests/``), so the test passes regardless of the working
    directory pytest is launched from.
    """
    data_path = (
        Path(__file__).resolve().parents[1] / "data" / "demo_equipment_costs.csv"
    )
    service = DemoModelService(data_path)
    summary = service.get_dataset_summary()

    assert summary["row_count"] >= 20
    assert "actual_cost" in summary["columns"]
    assert summary["target"] == "actual_cost"

    # Guard before indexing so an empty preview fails with a readable
    # assertion instead of an IndexError.
    preview = summary["preview"]
    assert len(preview) > 0, "dataset preview should contain at least one row"
    first_row = preview[0]
    assert first_row["name"]
    # Equipment type must be one of the demo categories
    # ("loitering munition" / "rocket artillery").
    assert first_row["type"] in {"巡飞弹", "火箭炮"}
def test_demo_service_returns_chinese_algorithm_names_with_english_notes():
    """Algorithm catalogue entries carry a Chinese name plus an English name.

    Checks the ``linear`` entry: Chinese display name "线性回归" (linear
    regression), English name "Linear Regression", and Chinese family label
    "线性模型" (linear model).
    """
    data_path = (
        Path(__file__).resolve().parents[1] / "data" / "demo_equipment_costs.csv"
    )
    service = DemoModelService(data_path)
    algorithms = service.get_algorithms()

    # Supply a default to next() so a missing entry fails as a clear
    # assertion rather than an opaque StopIteration.
    linear = next((item for item in algorithms if item["key"] == "linear"), None)
    assert linear is not None, "algorithm catalogue should include 'linear'"
    assert linear["name"] == "线性回归"
    assert linear["english_name"] == "Linear Regression"
    assert linear["family"] == "线性模型"
def test_demo_service_runs_multiple_algorithms():
    """Running several algorithms yields metrics and predictions for each one.

    The CSV path is resolved relative to this test file so the test does not
    depend on pytest's working directory.
    """
    data_path = (
        Path(__file__).resolve().parents[1] / "data" / "demo_equipment_costs.csv"
    )
    service = DemoModelService(data_path)
    requested = ["linear", "random_forest", "gradient_boosting"]
    result = service.run_demo(requested)

    assert result["source"] == "local-file"
    # Compare the key set, not just the count, so a wrong or misnamed
    # algorithm key is caught as well as a missing one.
    assert set(result["metrics"]) == set(requested)
    assert result["best_model"] in result["metrics"]
    assert len(result["prediction_points"]) > 0
    assert len(result["sample_prediction"]["predictions"]) == len(requested)
    for model_key, metrics in result["metrics"].items():
        # Include the model key in the failure message to pinpoint the culprit.
        assert {"r2", "mae", "rmse"}.issubset(metrics), model_key
def test_demo_service_ignores_unavailable_algorithms():
    """Unknown algorithm keys are skipped and surfaced as warnings.

    The CSV path is resolved relative to this test file so the test does not
    depend on pytest's working directory.
    """
    data_path = (
        Path(__file__).resolve().parents[1] / "data" / "demo_equipment_costs.csv"
    )
    service = DemoModelService(data_path)
    result = service.run_demo(["linear", "does_not_exist"])

    # Only the valid algorithm produces metrics; the unknown key is dropped.
    assert list(result["metrics"].keys()) == ["linear"]
    # The skipped key must be reported through at least one warning.
    assert result["warnings"]