50 lines
1.7 KiB
Python
50 lines
1.7 KiB
Python
from pathlib import Path
|
|
|
|
from src.demo_service import DemoModelService
|
|
|
|
|
|
def test_demo_service_loads_local_dataset():
    """The bundled demo CSV loads and its summary exposes schema and a preview.

    NOTE(review): the dataset path is relative, so it resolves against the
    current working directory; this presumably assumes pytest is run from the
    project root — confirm.
    """
    service = DemoModelService(Path("data/demo_equipment_costs.csv"))

    summary = service.get_dataset_summary()

    assert summary["row_count"] >= 20
    assert "actual_cost" in summary["columns"]
    assert summary["target"] == "actual_cost"

    # Check the preview is non-empty first so an empty preview fails with a
    # readable assertion instead of an IndexError on the [0] lookup below.
    preview = summary["preview"]
    assert preview, "dataset summary preview should contain at least one row"

    first_row = preview[0]
    assert first_row["name"]
    # Equipment type of the first preview row must be one of the two demo
    # categories: "巡飞弹" (loitering munition) or "火箭炮" (rocket artillery).
    assert first_row["type"] in {"巡飞弹", "火箭炮"}
|
|
|
|
|
|
def test_demo_service_returns_chinese_algorithm_names_with_english_notes():
    """The "linear" algorithm entry has a Chinese name, English name and family."""
    service = DemoModelService(Path("data/demo_equipment_costs.csv"))

    algorithms = service.get_algorithms()

    # Give next() a default so a missing entry fails as a clear assertion
    # instead of leaking StopIteration out of the test as an error.
    linear = next(
        (item for item in algorithms if item["key"] == "linear"),
        None,
    )
    assert linear is not None, "expected an algorithm with key 'linear'"

    assert linear["name"] == "线性回归"  # Chinese for "Linear Regression"
    assert linear["english_name"] == "Linear Regression"
    assert linear["family"] == "线性模型"  # Chinese for "linear models"
|
|
|
|
|
|
def test_demo_service_runs_multiple_algorithms():
    """Running several algorithms yields per-model metrics and predictions."""
    requested = ("linear", "random_forest", "gradient_boosting")
    service = DemoModelService(Path("data/demo_equipment_costs.csv"))

    result = service.run_demo(list(requested))

    assert result["source"] == "local-file"
    # Metrics are keyed by algorithm key, so require exactly the requested
    # set rather than merely "three entries of anything".
    assert set(result["metrics"]) == set(requested)
    assert result["best_model"] in result["metrics"]
    assert len(result["prediction_points"]) > 0
    # One sample prediction per requested model.
    assert len(result["sample_prediction"]["predictions"]) == len(requested)

    for name, metrics in result["metrics"].items():
        assert {"r2", "mae", "rmse"}.issubset(metrics), (
            f"missing metrics for {name}"
        )
|
|
|
|
|
|
def test_demo_service_ignores_unavailable_algorithms():
    """Unknown algorithm keys are skipped and reported through warnings."""
    csv_path = Path("data/demo_equipment_costs.csv")
    service = DemoModelService(csv_path)

    outcome = service.run_demo(["linear", "does_not_exist"])
    trained = list(outcome["metrics"].keys())

    # Only the known algorithm should have produced metrics...
    assert trained == ["linear"]
    # ...and the unknown key should surface as a non-empty warnings entry.
    assert outcome["warnings"]
|