diff --git a/.gitignore b/.gitignore index 468c727..45e61c8 100644 --- a/.gitignore +++ b/.gitignore @@ -42,7 +42,8 @@ node_modules /models /logs /uploads -/data +/data/* +!/data/demo_equipment_costs.csv # local env files .env.local @@ -65,8 +66,9 @@ pnpm-debug.log* *.sw? /frontend/node_modules /frontend/dist +/release/ *.zip *.tar *.gz -*.whl \ No newline at end of file +*.whl diff --git a/data/demo_equipment_costs.csv b/data/demo_equipment_costs.csv new file mode 100644 index 0000000..2b9347c --- /dev/null +++ b/data/demo_equipment_costs.csv @@ -0,0 +1,27 @@ +name,type,length_m,width_m,height_m,weight_kg,max_range_km,payload_kg,max_speed_kmh,endurance_min,tech_level,scale_level,supply_chain_level,complexity_score,actual_cost +隼击-A,巡飞弹,1.2,1.8,0.32,18,35,4,145,55,6.4,5.8,6.2,5.9,420000 +隼击-B,巡飞弹,1.5,2.1,0.36,26,48,6,160,70,6.8,6.2,6.4,6.6,610000 +隼击-C,巡飞弹,1.8,2.5,0.42,34,65,8,175,85,7.2,6.4,6.8,7.1,830000 +侦察-100,巡飞弹,0.9,1.4,0.25,9,18,2,110,35,5.4,5.1,5.5,4.8,190000 +侦察-200,巡飞弹,1.1,1.7,0.29,14,28,3,125,48,5.9,5.4,5.7,5.3,310000 +侦察-300,巡飞弹,1.4,2.0,0.34,22,42,5,150,62,6.3,5.9,6.0,6.1,520000 +锐蛇-S,巡飞弹,1.7,2.4,0.38,30,58,7,185,76,7.5,6.7,6.9,7.4,940000 +锐蛇-M,巡飞弹,2.0,2.8,0.46,44,82,10,205,94,8.0,7.1,7.3,8.0,1360000 +锐蛇-L,巡飞弹,2.4,3.2,0.55,62,120,15,230,125,8.7,7.5,7.8,8.8,2100000 +鹰眼-1,巡飞弹,1.3,1.9,0.31,20,40,4,155,58,6.6,5.7,6.3,6.0,470000 +鹰眼-2,巡飞弹,1.6,2.2,0.37,29,57,7,172,78,7.1,6.1,6.6,6.9,760000 +鹰眼-3,巡飞弹,2.1,2.9,0.49,51,95,12,215,105,8.2,7.0,7.2,8.1,1580000 +雷霆-122,火箭炮,6.9,2.4,2.8,13500,22,480,72,0,5.8,6.6,6.0,5.5,980000 +雷霆-160,火箭炮,7.6,2.6,3.0,16800,40,760,68,0,6.4,6.9,6.3,6.1,1450000 +雷霆-220,火箭炮,8.3,2.8,3.2,21500,70,1200,65,0,7.0,7.1,6.8,7.0,2380000 +雷霆-300,火箭炮,9.8,3.0,3.4,28500,120,1850,62,0,7.8,7.4,7.2,8.0,4200000 +山猫-95,火箭炮,6.2,2.3,2.7,11800,18,360,78,0,5.4,6.0,5.7,5.0,740000 +山猫-120,火箭炮,6.7,2.4,2.8,13000,30,520,75,0,5.9,6.2,6.0,5.6,1050000 +山猫-200,火箭炮,7.9,2.7,3.1,19800,60,980,70,0,6.8,6.8,6.5,6.7,1980000 +山猫-300,火箭炮,9.3,2.9,3.3,26000,105,1600,66,0,7.6,7.2,7.0,7.8,3560000 +弓兵-L,火箭炮,8.8,2.9,3.2,23500,85,1350,69,0,7.2,7.0,6.9,7.3,2860000 +弓兵-X,火箭炮,10.2,3.1,3.6,31000,150,2100,60,0,8.4,7.8,7.6,8.7,5400000 +长矛-1,火箭炮,7.1,2.5,2.9,14200,28,560,73,0,6.1,6.4,6.1,5.8,1180000 +长矛-2,火箭炮,8.1,2.7,3.1,20500,75,1120,68,0,7.1,6.9,6.7,7.1,2420000 +长矛-3,火箭炮,9.6,3.0,3.5,29200,130,1900,63,0,8.1,7.5,7.4,8.3,4650000 +擎天-M,火箭炮,10.8,3.2,3.8,34800,180,2450,58,0,8.9,8.0,7.9,9.2,6900000 diff --git a/demo_standalone/README.md b/demo_standalone/README.md new file mode 100644 index 0000000..a609bcb --- /dev/null +++ b/demo_standalone/README.md @@ -0,0 +1,13 @@ +# 机器学习算法演示 + +## 运行方式 + +1. 解压 zip 文件。 +2. 双击 `start_demo.bat`。 +3. 浏览器会自动打开 `http://127.0.0.1:5001/algorithm-demo`。 + +## 说明 + +- 演示使用 `data/demo_equipment_costs.csv`,不需要 MySQL。 +- 首次运行会创建 `.venv` 并安装最小 Python 依赖。 +- 需要本机已安装 Python 3.9 至 3.11。 diff --git a/demo_standalone/requirements.txt b/demo_standalone/requirements.txt new file mode 100644 index 0000000..dc85886 --- /dev/null +++ b/demo_standalone/requirements.txt @@ -0,0 +1,5 @@ +flask>=3.1.0 +flask-cors>=5.0.0 +numpy>=1.26.0,<2.0.0 +pandas>=2.2.0 +scikit-learn>=1.5.2 diff --git a/demo_standalone/server.py b/demo_standalone/server.py new file mode 100644 index 0000000..d416c7d --- /dev/null +++ b/demo_standalone/server.py @@ -0,0 +1,48 @@ +from pathlib import Path + +from flask import Flask, jsonify, request, send_from_directory +from flask_cors import CORS + +from demo_service import DemoModelService + + +BASE_DIR = Path(__file__).resolve().parent +STATIC_DIR = BASE_DIR / "frontend" +DATASET_PATH = BASE_DIR / "data" / "demo_equipment_costs.csv" + + +def create_app(): + app = Flask(__name__, static_folder=None) + CORS(app) + + @app.get("/api/demo/algorithms") + def demo_algorithms(): + service = DemoModelService(DATASET_PATH) + return jsonify({"algorithms": service.get_algorithms()}) + + @app.get("/api/demo/dataset") + def demo_dataset(): + service = DemoModelService(DATASET_PATH) + return jsonify(service.get_dataset_summary()) + + @app.post("/api/demo/run") + def demo_run(): + payload = request.get_json(silent=True) or {} + service = DemoModelService(DATASET_PATH) + return jsonify(service.run_demo(payload.get("algorithms"))) + + @app.get("/") + @app.get("/") + def frontend(path=""): + file_path = STATIC_DIR / path + if path and file_path.exists() and file_path.is_file(): + return send_from_directory(STATIC_DIR, path) + return send_from_directory(STATIC_DIR, "index.html") + + return app + + +if __name__ == "__main__": + app = create_app() + print("算法演示服务已启动:http://127.0.0.1:5001/algorithm-demo") + app.run(host="127.0.0.1", port=5001, debug=False) diff --git a/demo_standalone/start_demo.bat b/demo_standalone/start_demo.bat new file mode 100644 index 0000000..c7dac6c --- /dev/null +++ b/demo_standalone/start_demo.bat @@ -0,0 +1,32 @@ +@echo off +setlocal +cd /d "%~dp0" + +where python >nul 2>nul +if errorlevel 1 ( + echo 未找到 Python。请先安装 Python 3.9 至 3.11,然后重新运行本脚本。 + pause + exit /b 1 +) + +if not exist ".venv\Scripts\python.exe" ( + echo 正在创建演示环境... + python -m venv .venv + if errorlevel 1 ( + echo 创建环境失败。 + pause + exit /b 1 + ) +) + +echo 正在安装或检查依赖... +".venv\Scripts\python.exe" -m pip install -r requirements.txt +if errorlevel 1 ( + echo 依赖安装失败,请检查网络或 Python 环境。 + pause + exit /b 1 +) + +start "" http://127.0.0.1:5001/algorithm-demo +".venv\Scripts\python.exe" server.py +pause diff --git a/docs/superpowers/plans/2026-04-25-ml-algorithm-demo.md b/docs/superpowers/plans/2026-04-25-ml-algorithm-demo.md new file mode 100644 index 0000000..081698b --- /dev/null +++ b/docs/superpowers/plans/2026-04-25-ml-algorithm-demo.md @@ -0,0 +1,57 @@ +# ML Algorithm Demo Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add a modern demo page that compares common machine learning algorithms using a local data file instead of MySQL. + +**Architecture:** Add an isolated backend demo service that reads `data/demo_equipment_costs.csv`, trains selected regressors in memory, and returns metrics, prediction points, feature importance, and a sample prediction. Add a Vue route that calls the demo API and renders algorithm switching, charts, metrics, and data preview. Existing database-backed pages remain unchanged. + +**Tech Stack:** Flask, pandas, scikit-learn, optional xgboost/lightgbm, Vue 3, Element Plus, ECharts. + +--- + +### Task 1: Backend Demo Service + +**Files:** +- Create: `tests/test_demo_service.py` +- Create: `src/demo_service.py` +- Create: `data/demo_equipment_costs.csv` + +- [ ] Write failing tests for data loading, algorithm availability, and training payload shape. +- [ ] Run `python -m pytest tests/test_demo_service.py -q` and verify it fails because `src.demo_service` is missing. +- [ ] Implement `DemoModelService` with local CSV loading, selected algorithm training, metric calculation, top feature importance, and fallback algorithms when optional libraries are unavailable. +- [ ] Run `python -m pytest tests/test_demo_service.py -q` and verify it passes. + +### Task 2: Demo API + +**Files:** +- Modify: `src/routes.py` +- Test: `tests/test_demo_routes.py` + +- [ ] Write Flask route tests for `GET /api/demo/algorithms`, `GET /api/demo/dataset`, and `POST /api/demo/run`. +- [ ] Run `python -m pytest tests/test_demo_routes.py -q` and verify missing routes fail. +- [ ] Add demo routes that call `DemoModelService` and do not access MySQL. +- [ ] Run the route tests and demo service tests. + +### Task 3: Vue Demo Page + +**Files:** +- Create: `frontend/src/views/AlgorithmDemoPage.vue` +- Modify: `frontend/src/router/index.js` +- Modify: `frontend/src/App.vue` +- Modify: `frontend/src/api/index.js` +- Modify: `frontend/src/views/HomePage.vue` + +- [ ] Add API helpers for demo algorithms, dataset, and run. +- [ ] Add `/algorithm-demo` route and navigation label `算法演示`. +- [ ] Build a modern dashboard-style page with algorithm toggles, metric cards, comparison chart, predicted-vs-actual chart, feature importance chart, sample prediction panel, and data preview table. +- [ ] Add a home page entry that links to the demo. + +### Task 4: Verification + +**Files:** +- No new files. + +- [ ] Run `python -m pytest tests/test_demo_service.py tests/test_demo_routes.py -q`. +- [ ] Run `npm run build` in `frontend`. +- [ ] Start the app if feasible and confirm the new route is available. diff --git a/frontend/src/App.vue b/frontend/src/App.vue index c4278c5..9050921 100644 --- a/frontend/src/App.vue +++ b/frontend/src/App.vue @@ -9,6 +9,7 @@ 首页 成本预测 特征分析 + 算法演示 模型训练 模型管理 数据集管理 diff --git a/frontend/src/api/index.js b/frontend/src/api/index.js index 1090d1a..fad3eb4 100644 --- a/frontend/src/api/index.js +++ b/frontend/src/api/index.js @@ -40,4 +40,16 @@ export const updateEquipment = (id, data) => { export const deleteEquipment = (id) => { return api.delete(`/data/${id}`) -} \ No newline at end of file +} + +export const getDemoAlgorithms = () => { + return api.get('/demo/algorithms') +} + +export const getDemoDataset = () => { + return api.get('/demo/dataset') +} + +export const runAlgorithmDemo = (data) => { + return api.post('/demo/run', data) +} diff --git a/frontend/src/config.js b/frontend/src/config.js index 0a697ee..96cc5f9 100644 --- a/frontend/src/config.js +++ b/frontend/src/config.js @@ -1,8 +1,12 @@ -export const API_BASE_URL = 'http://localhost:5001/api'; +const isLocalDevServer = window.location.port === '8080' + +export const API_BASE_URL = isLocalDevServer + ? 'http://localhost:5001/api' + : `${window.location.origin}/api`; export const DB_CONFIG = { host: 'localhost', user: 'root', password: '123456', database: 'equipment_cost_db' -}; \ No newline at end of file +}; diff --git a/frontend/src/router/index.js b/frontend/src/router/index.js index 3dfcdee..59994af 100644 --- a/frontend/src/router/index.js +++ b/frontend/src/router/index.js @@ -5,6 +5,7 @@ import DatasetPage from '@/views/DatasetPage.vue' import PredictPage from '@/views/PredictPage.vue' import AnalysisPage from '@/views/AnalysisPage.vue' import TrainingPage from '@/views/TrainingPage.vue' +import AlgorithmDemoPage from '@/views/AlgorithmDemoPage.vue' const routes = [ { @@ -37,6 +38,11 @@ const routes = [ name: 'Training', component: TrainingPage }, + { + path: '/algorithm-demo', + name: 'AlgorithmDemo', + component: AlgorithmDemoPage + }, { path: '/models', name: 'Models', @@ -49,4 +55,4 @@ const router = createRouter({ routes }) -export default router \ No newline at end of file +export default router diff --git a/frontend/src/views/AlgorithmDemoPage.vue b/frontend/src/views/AlgorithmDemoPage.vue new file mode 100644 index 0000000..b95724a --- /dev/null +++ b/frontend/src/views/AlgorithmDemoPage.vue @@ -0,0 +1,644 @@ + + + + + diff --git a/frontend/src/views/HomePage.vue b/frontend/src/views/HomePage.vue index 60c0676..fae5d28 100644 --- a/frontend/src/views/HomePage.vue +++ b/frontend/src/views/HomePage.vue @@ -26,6 +26,13 @@

训练和优化预测模型

+ + + +

算法演示

+

切换常用机器学习算法并对比预测效果

+
+
@@ -53,7 +60,7 @@ \ No newline at end of file + diff --git a/html5_cost_prediction/README.txt b/html5_cost_prediction/README.txt new file mode 100644 index 0000000..9e5e784 --- /dev/null +++ b/html5_cost_prediction/README.txt @@ -0,0 +1,11 @@ +智能成本预测系统 - HTML5离线版 + +运行方式: +1. 解压 zip 文件。 +2. 双击 index.html。 + +说明: +- 不需要 Python。 +- 不需要数据库。 +- 不需要联网。 +- 页面内置样例数据和模型效果,用于客户现场展示不同模型的预测差异。 diff --git a/html5_cost_prediction/index.html b/html5_cost_prediction/index.html new file mode 100644 index 0000000..cb42457 --- /dev/null +++ b/html5_cost_prediction/index.html @@ -0,0 +1,1324 @@ + + + + + + 智能成本预测系统 + + + +
+
+
+ +
+ 智能成本预测系统 + 多模型融合 · 参数感知 · 决策辅助 +
+
+
+ +
+
+
+ +
AI COST INTELLIGENCE
+

智能成本预测系统

+

+ 汇集装备参数、技术成熟度、供应链能力与复杂度评分,快速切换多种智能模型, + 直观看到预测对比、误差指标和关键影响因子。 +

+
+ +
+
+ +
+ +
+ + +
+
+
+
+

核心指标

+

当前模型的预测质量与成本输出

+
+ 线性回归 +
+
+
模型综合评分0.000
+
平均绝对误差¥0
+
均方根误差¥0
+
成本区间¥0
+
+
+ +
+
+
+

预测对比

+

预测柱状图与真实成本曲线对比

+
+
+
+ +
+
+ +
+
+
+

关键因子

+

影响成本的主要参数

+
+
+
+
+ +
+
+
+

模型对比

+

多模型综合评分横向比较

+
+
+
+ +
+
+ +
+
+
+

智能洞察

+

自动生成的辅助判断

+
+
+
+
+ +
+
+
+

数据画像

+

内置样例数据用于现场离线展示

+
+
+
+ + + + + + + + + + + + + +
名称类型重量射程速度技术水平实际成本
+
+
+
+
+
+ + + + diff --git a/scripts/build_demo_zip.ps1 b/scripts/build_demo_zip.ps1 new file mode 100644 index 0000000..08cba0a --- /dev/null +++ b/scripts/build_demo_zip.ps1 @@ -0,0 +1,57 @@ +param( + [string]$OutputPath = "release\algorithm-demo-standalone.zip" +) + +$ErrorActionPreference = "Stop" + +$repoRoot = Resolve-Path (Join-Path $PSScriptRoot "..") +$releaseRoot = Join-Path $repoRoot "release" +$stageDir = Join-Path $releaseRoot "algorithm-demo-standalone" +$zipPath = Join-Path $repoRoot $OutputPath + +function Assert-InRepo([string]$PathToCheck) { + $resolved = [System.IO.Path]::GetFullPath($PathToCheck) + $root = [System.IO.Path]::GetFullPath($repoRoot) + if (-not $resolved.StartsWith($root, [System.StringComparison]::OrdinalIgnoreCase)) { + throw "Refusing to operate outside repository: $resolved" + } +} + +Assert-InRepo $stageDir +Assert-InRepo $zipPath + +Push-Location (Join-Path $repoRoot "frontend") +try { + npm run build +} +finally { + Pop-Location +} + +New-Item -ItemType Directory -Force -Path $releaseRoot | Out-Null + +if (Test-Path $stageDir) { + Remove-Item -LiteralPath $stageDir -Recurse -Force +} +if (Test-Path $zipPath) { + Remove-Item -LiteralPath $zipPath -Force +} + +New-Item -ItemType Directory -Force -Path $stageDir | Out-Null +New-Item -ItemType Directory -Force -Path (Join-Path $stageDir "data") | Out-Null + +Copy-Item -Recurse -Path (Join-Path $repoRoot "frontend\dist") -Destination (Join-Path $stageDir "frontend") +Copy-Item -Path (Join-Path $repoRoot "src\demo_service.py") -Destination (Join-Path $stageDir "demo_service.py") +Copy-Item -Path (Join-Path $repoRoot "data\demo_equipment_costs.csv") -Destination (Join-Path $stageDir "data\demo_equipment_costs.csv") +Copy-Item -Path (Join-Path $repoRoot "demo_standalone\server.py") -Destination (Join-Path $stageDir "server.py") +Copy-Item -Path (Join-Path $repoRoot "demo_standalone\requirements.txt") -Destination (Join-Path $stageDir "requirements.txt") +Copy-Item -Path (Join-Path $repoRoot "demo_standalone\start_demo.bat") -Destination (Join-Path $stageDir "start_demo.bat") +Copy-Item -Path (Join-Path $repoRoot "demo_standalone\README.md") -Destination (Join-Path $stageDir "README.md") + +Get-ChildItem -Path $stageDir -Recurse -Include "*.map", "__pycache__" | ForEach-Object { + Remove-Item -LiteralPath $_.FullName -Recurse -Force +} + +Compress-Archive -Path (Join-Path $stageDir "*") -DestinationPath $zipPath -Force + +Write-Host "Demo zip created: $zipPath" diff --git a/scripts/build_html5_zip.ps1 b/scripts/build_html5_zip.ps1 new file mode 100644 index 0000000..c0e713e --- /dev/null +++ b/scripts/build_html5_zip.ps1 @@ -0,0 +1,36 @@ +param( + [string]$OutputPath = "release\intelligent-cost-prediction-html5.zip" +) + +$ErrorActionPreference = "Stop" + +$repoRoot = Resolve-Path (Join-Path $PSScriptRoot "..") +$sourceDir = Join-Path $repoRoot "html5_cost_prediction" +$releaseRoot = Join-Path $repoRoot "release" +$stageDir = Join-Path $releaseRoot "intelligent-cost-prediction-html5" +$zipPath = Join-Path $repoRoot $OutputPath + +function Assert-InRepo([string]$PathToCheck) { + $resolved = [System.IO.Path]::GetFullPath($PathToCheck) + $root = [System.IO.Path]::GetFullPath($repoRoot) + if (-not $resolved.StartsWith($root, [System.StringComparison]::OrdinalIgnoreCase)) { + throw "Refusing to operate outside repository: $resolved" + } +} + +Assert-InRepo $stageDir +Assert-InRepo $zipPath + +New-Item -ItemType Directory -Force -Path $releaseRoot | Out-Null + +if (Test-Path $stageDir) { + Remove-Item -LiteralPath $stageDir -Recurse -Force +} +if (Test-Path $zipPath) { + Remove-Item -LiteralPath $zipPath -Force +} + +Copy-Item -Recurse -Path $sourceDir -Destination $stageDir +Compress-Archive -Path (Join-Path $stageDir "*") -DestinationPath $zipPath -Force + +Write-Host "HTML5 zip created: $zipPath" diff --git a/src/demo_service.py b/src/demo_service.py new file mode 100644 index 0000000..fa03462 --- /dev/null +++ b/src/demo_service.py @@ -0,0 +1,290 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np +import pandas as pd +from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor +from sklearn.linear_model import LinearRegression, Ridge +from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score +from sklearn.model_selection import train_test_split +from sklearn.neighbors import KNeighborsRegressor +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.svm import SVR + + +@dataclass(frozen=True) +class AlgorithmDefinition: + key: str + name: str + english_name: str + family: str + description: str + estimator: Any + + +class DemoModelService: + target_column = "actual_cost" + ignored_columns = {"name", "type", target_column} + + def __init__(self, dataset_path: Path | str | None = None): + root = Path(__file__).resolve().parent.parent + self.dataset_path = Path(dataset_path) if dataset_path else root / "data" / "demo_equipment_costs.csv" + + def get_algorithms(self) -> list[dict[str, str]]: + algorithms, _ = self._available_algorithms() + return [ + { + "key": item.key, + "name": item.name, + "english_name": item.english_name, + "family": item.family, + "description": item.description, + } + for item in algorithms.values() + ] + + def get_dataset_summary(self) -> dict[str, Any]: + frame = self._load_dataset() + feature_columns = self._feature_columns(frame) + return { + "source": "local-file", + "path": str(self.dataset_path), + "row_count": int(len(frame)), + "columns": list(frame.columns), + "features": feature_columns, + "target": self.target_column, + "target_label": "实际成本", + "equipment_types": sorted(frame["type"].dropna().unique().tolist()), + "preview": frame.head(8).to_dict(orient="records"), + } + + def run_demo(self, selected_algorithms: list[str] | None = None) -> dict[str, Any]: + frame = self._load_dataset() + feature_columns = self._feature_columns(frame) + algorithms, availability_warnings = self._available_algorithms() + + requested = selected_algorithms or list(algorithms.keys()) + warnings = list(availability_warnings) + selected = [] + for key in requested: + if key in algorithms: + selected.append(key) + else: + warnings.append(f"算法 '{key}' 不可用,已自动跳过。") + + if not selected: + selected = ["linear"] + warnings.append("所选算法均不可用,已自动使用线性回归。") + + X = frame[feature_columns] + y = frame[self.target_column] + train_x, test_x, train_y, test_y = train_test_split( + X, + y, + test_size=0.3, + random_state=42, + ) + + metrics: dict[str, dict[str, float | str]] = {} + predictions: dict[str, list[float]] = {} + feature_importance: dict[str, list[dict[str, float | str]]] = {} + + for key in selected: + definition = algorithms[key] + model = definition.estimator + model.fit(train_x, train_y) + predicted = model.predict(test_x) + predictions[key] = [float(value) for value in predicted] + metrics[key] = { + "name": definition.name, + "r2": float(r2_score(test_y, predicted)), + "mae": float(mean_absolute_error(test_y, predicted)), + "rmse": float(np.sqrt(mean_squared_error(test_y, predicted))), + } + feature_importance[key] = self._feature_importance(model, feature_columns) + + best_model = min(metrics, key=lambda key: float(metrics[key]["rmse"])) + ordered_test = test_x.copy() + ordered_test["actual"] = test_y + ordered_test["name"] = frame.loc[test_x.index, "name"] + + prediction_points = [] + for position, (index, row) in enumerate(ordered_test.sort_values("actual").iterrows()): + point = { + "name": row["name"], + "actual": float(row["actual"]), + } + for key in selected: + original_position = list(test_x.index).index(index) + point[key] = predictions[key][original_position] + prediction_points.append(point) + + sample = frame.sort_values(self.target_column).iloc[len(frame) // 2] + sample_x = pd.DataFrame([sample[feature_columns].to_dict()]) + sample_predictions = { + key: float(algorithms[key].estimator.fit(X, y).predict(sample_x)[0]) + for key in selected + } + + return { + "source": "local-file", + "dataset": self.get_dataset_summary(), + "algorithms": self.get_algorithms(), + "selected_algorithms": selected, + "best_model": best_model, + "metrics": metrics, + "feature_importance": feature_importance, + "prediction_points": prediction_points, + "sample_prediction": { + "input": sample.drop(labels=[self.target_column]).to_dict(), + "actual": float(sample[self.target_column]), + "predictions": sample_predictions, + }, + "warnings": warnings, + } + + def _load_dataset(self) -> pd.DataFrame: + if not self.dataset_path.exists(): + raise FileNotFoundError(f"Demo dataset not found: {self.dataset_path}") + + frame = pd.read_csv(self.dataset_path) + if self.target_column not in frame.columns: + raise ValueError(f"Demo dataset must include '{self.target_column}'.") + return frame + + def _feature_columns(self, frame: pd.DataFrame) -> list[str]: + columns = [ + column + for column in frame.columns + if column not in self.ignored_columns and pd.api.types.is_numeric_dtype(frame[column]) + ] + if not columns: + raise ValueError("Demo dataset has no numeric feature columns.") + return columns + + def _available_algorithms(self) -> tuple[dict[str, AlgorithmDefinition], list[str]]: + algorithms = { + "linear": AlgorithmDefinition( + "linear", + "线性回归", + "Linear Regression", + "线性模型", + "快速建立基准模型,用于展示参数与成本之间的线性关系。", + Pipeline([("scaler", StandardScaler()), ("model", LinearRegression())]), + ), + "ridge": AlgorithmDefinition( + "ridge", + "岭回归", + "Ridge Regression", + "线性模型", + "带正则化的线性模型,适合特征存在相关性的场景。", + Pipeline([("scaler", StandardScaler()), ("model", Ridge(alpha=1.0))]), + ), + "random_forest": AlgorithmDefinition( + "random_forest", + "随机森林", + "Random Forest", + "树模型集成", + "通过多棵决策树集成预测,能够捕捉非线性特征影响。", + RandomForestRegressor(n_estimators=160, max_depth=6, random_state=42), + ), + "gradient_boosting": AlgorithmDefinition( + "gradient_boosting", + "梯度提升树", + "Gradient Boosting", + "树模型集成", + "逐步修正误差的提升模型,常用于表格数据回归任务。", + GradientBoostingRegressor(n_estimators=120, learning_rate=0.06, max_depth=3, random_state=42), + ), + "svr": AlgorithmDefinition( + "svr", + "支持向量回归", + "Support Vector Regression", + "核方法", + "使用核函数拟合平滑回归关系,适合展示不同算法偏好。", + Pipeline([("scaler", StandardScaler()), ("model", SVR(C=500000, epsilon=50000))]), + ), + "knn": AlgorithmDefinition( + "knn", + "近邻回归", + "KNN Regression", + "实例学习", + "基于相似样本进行预测,便于解释局部相似性。", + Pipeline([("scaler", StandardScaler()), ("model", KNeighborsRegressor(n_neighbors=4))]), + ), + } + warnings = [] + + try: + from xgboost import XGBRegressor + + algorithms["xgboost"] = AlgorithmDefinition( + "xgboost", + "XGBoost", + "XGBoost", + "提升模型", + "面向表格数据的高性能梯度提升实现。", + XGBRegressor( + n_estimators=120, + max_depth=3, + learning_rate=0.05, + subsample=0.9, + colsample_bytree=0.9, + random_state=42, + objective="reg:squarederror", + ), + ) + except Exception: + warnings.append("当前环境未安装 XGBoost,页面已自动隐藏该算法。") + + try: + from lightgbm import LGBMRegressor + + algorithms["lightgbm"] = AlgorithmDefinition( + "lightgbm", + "LightGBM", + "LightGBM", + "提升模型", + "基于直方图优化的快速梯度提升模型。", + LGBMRegressor( + n_estimators=120, + learning_rate=0.05, + max_depth=4, + random_state=42, + verbose=-1, + ), + ) + except Exception: + warnings.append("当前环境未安装 LightGBM,页面已自动隐藏该算法。") + + return algorithms, warnings + + def _feature_importance(self, model: Any, feature_columns: list[str]) -> list[dict[str, float | str]]: + estimator = model + if isinstance(model, Pipeline): + estimator = model.named_steps["model"] + + if hasattr(estimator, "feature_importances_"): + values = estimator.feature_importances_ + elif hasattr(estimator, "coef_"): + values = np.abs(np.ravel(estimator.coef_)) + else: + values = np.zeros(len(feature_columns)) + + total = float(np.sum(values)) + if total > 0: + values = values / total + + ranked = sorted( + [ + {"feature": feature, "importance": float(value)} + for feature, value in zip(feature_columns, values) + ], + key=lambda item: item["importance"], + reverse=True, + ) + return ranked[:8] diff --git a/src/routes.py b/src/routes.py index fbdaa84..3382743 100644 --- a/src/routes.py +++ b/src/routes.py @@ -11,6 +11,7 @@ import json import os from .data_preparation import DataPreparation from .model_trainer import ModelTrainer +from .demo_service import DemoModelService from .logger import setup_logger import torch @@ -58,6 +59,38 @@ def index(): } }) +@api_bp.route('/demo/algorithms', methods=['GET']) +def demo_algorithms(): + """Return algorithms supported by the file-based demo.""" + try: + service = DemoModelService() + return jsonify({'algorithms': service.get_algorithms()}) + except Exception as e: + logger.error(f"Error getting demo algorithms: {str(e)}") + return jsonify({'error': str(e)}), 500 + +@api_bp.route('/demo/dataset', methods=['GET']) +def demo_dataset(): + """Return the local demo dataset summary without using MySQL.""" + try: + service = DemoModelService() + return jsonify(service.get_dataset_summary()) + except Exception as e: + logger.error(f"Error getting demo dataset: {str(e)}") + return jsonify({'error': str(e)}), 500 + +@api_bp.route('/demo/run', methods=['POST']) +def demo_run(): + """Run selected demo algorithms against the local data file.""" + try: + data = request.get_json(silent=True) or {} + service = DemoModelService() + return jsonify(service.run_demo(data.get('algorithms'))) + except Exception as e: + logger.error(f"Error running demo algorithms: {str(e)}") + logger.error("Detailed traceback:", exc_info=True) + return jsonify({'error': str(e)}), 500 + @api_bp.route('/predict', methods=['POST']) def predict(): """使用最优机器学习模型进行预测""" @@ -1282,4 +1315,4 @@ def analyze_manufacturers(): except Exception as e: logger.error(f"Error analyzing manufacturers: {str(e)}") logger.error("Detailed traceback:", exc_info=True) - return jsonify({'error': str(e)}), 500 \ No newline at end of file + return jsonify({'error': str(e)}), 500 diff --git a/tests/test_demo_routes.py b/tests/test_demo_routes.py new file mode 100644 index 0000000..c462847 --- /dev/null +++ b/tests/test_demo_routes.py @@ -0,0 +1,40 @@ +from src import create_app + + +def test_demo_algorithms_route_returns_available_models(): + app = create_app() + client = app.test_client() + + response = client.get("/api/demo/algorithms") + + assert response.status_code == 200 + payload = response.get_json() + assert any(item["key"] == "random_forest" for item in payload["algorithms"]) + + +def test_demo_dataset_route_returns_local_file_summary(): + app = create_app() + client = app.test_client() + + response = client.get("/api/demo/dataset") + + assert response.status_code == 200 + payload = response.get_json() + assert payload["source"] == "local-file" + assert payload["row_count"] >= 20 + + +def test_demo_run_route_returns_metrics_without_mysql(): + app = create_app() + client = app.test_client() + + response = client.post( + "/api/demo/run", + json={"algorithms": ["linear", "random_forest"]}, + ) + + assert response.status_code == 200 + payload = response.get_json() + assert payload["source"] == "local-file" + assert set(payload["metrics"]) == {"linear", "random_forest"} + assert payload["prediction_points"] diff --git a/tests/test_demo_service.py b/tests/test_demo_service.py new file mode 100644 index 0000000..019b222 --- /dev/null +++ b/tests/test_demo_service.py @@ -0,0 +1,49 @@ +from pathlib import Path + +from src.demo_service import DemoModelService + + +def test_demo_service_loads_local_dataset(): + service = DemoModelService(Path("data/demo_equipment_costs.csv")) + + summary = service.get_dataset_summary() + + assert summary["row_count"] >= 20 + assert "actual_cost" in summary["columns"] + assert summary["target"] == "actual_cost" + assert summary["preview"][0]["name"] + assert summary["preview"][0]["type"] in {"巡飞弹", "火箭炮"} + + +def test_demo_service_returns_chinese_algorithm_names_with_english_notes(): + service = DemoModelService(Path("data/demo_equipment_costs.csv")) + + algorithms = service.get_algorithms() + + linear = next(item for item in algorithms if item["key"] == "linear") + assert linear["name"] == "线性回归" + assert linear["english_name"] == "Linear Regression" + assert linear["family"] == "线性模型" + + +def test_demo_service_runs_multiple_algorithms(): + service = DemoModelService(Path("data/demo_equipment_costs.csv")) + + result = service.run_demo(["linear", "random_forest", "gradient_boosting"]) + + assert result["source"] == "local-file" + assert result["best_model"] in result["metrics"] + assert len(result["metrics"]) == 3 + assert len(result["prediction_points"]) > 0 + assert len(result["sample_prediction"]["predictions"]) == 3 + for metrics in result["metrics"].values(): + assert {"r2", "mae", "rmse"}.issubset(metrics) + + +def test_demo_service_ignores_unavailable_algorithms(): + service = DemoModelService(Path("data/demo_equipment_costs.csv")) + + result = service.run_demo(["linear", "does_not_exist"]) + + assert list(result["metrics"].keys()) == ["linear"] + assert result["warnings"]