Add offline intelligent cost prediction demo
This commit is contained in:
parent
485b4e497a
commit
137451ba7a
6
.gitignore
vendored
6
.gitignore
vendored
@ -42,7 +42,8 @@ node_modules
|
||||
/models
|
||||
/logs
|
||||
/uploads
|
||||
/data
|
||||
/data/*
|
||||
!/data/demo_equipment_costs.csv
|
||||
|
||||
# local env files
|
||||
.env.local
|
||||
@ -65,8 +66,9 @@ pnpm-debug.log*
|
||||
*.sw?
|
||||
/frontend/node_modules
|
||||
/frontend/dist
|
||||
/release/
|
||||
|
||||
*.zip
|
||||
*.tar
|
||||
*.gz
|
||||
*.whl
|
||||
*.whl
|
||||
|
||||
27
data/demo_equipment_costs.csv
Normal file
27
data/demo_equipment_costs.csv
Normal file
@ -0,0 +1,27 @@
|
||||
name,type,length_m,width_m,height_m,weight_kg,max_range_km,payload_kg,max_speed_kmh,endurance_min,tech_level,scale_level,supply_chain_level,complexity_score,actual_cost
|
||||
隼击-A,巡飞弹,1.2,1.8,0.32,18,35,4,145,55,6.4,5.8,6.2,5.9,420000
|
||||
隼击-B,巡飞弹,1.5,2.1,0.36,26,48,6,160,70,6.8,6.2,6.4,6.6,610000
|
||||
隼击-C,巡飞弹,1.8,2.5,0.42,34,65,8,175,85,7.2,6.4,6.8,7.1,830000
|
||||
侦察-100,巡飞弹,0.9,1.4,0.25,9,18,2,110,35,5.4,5.1,5.5,4.8,190000
|
||||
侦察-200,巡飞弹,1.1,1.7,0.29,14,28,3,125,48,5.9,5.4,5.7,5.3,310000
|
||||
侦察-300,巡飞弹,1.4,2.0,0.34,22,42,5,150,62,6.3,5.9,6.0,6.1,520000
|
||||
锐蛇-S,巡飞弹,1.7,2.4,0.38,30,58,7,185,76,7.5,6.7,6.9,7.4,940000
|
||||
锐蛇-M,巡飞弹,2.0,2.8,0.46,44,82,10,205,94,8.0,7.1,7.3,8.0,1360000
|
||||
锐蛇-L,巡飞弹,2.4,3.2,0.55,62,120,15,230,125,8.7,7.5,7.8,8.8,2100000
|
||||
鹰眼-1,巡飞弹,1.3,1.9,0.31,20,40,4,155,58,6.6,5.7,6.3,6.0,470000
|
||||
鹰眼-2,巡飞弹,1.6,2.2,0.37,29,57,7,172,78,7.1,6.1,6.6,6.9,760000
|
||||
鹰眼-3,巡飞弹,2.1,2.9,0.49,51,95,12,215,105,8.2,7.0,7.2,8.1,1580000
|
||||
雷霆-122,火箭炮,6.9,2.4,2.8,13500,22,480,72,0,5.8,6.6,6.0,5.5,980000
|
||||
雷霆-160,火箭炮,7.6,2.6,3.0,16800,40,760,68,0,6.4,6.9,6.3,6.1,1450000
|
||||
雷霆-220,火箭炮,8.3,2.8,3.2,21500,70,1200,65,0,7.0,7.1,6.8,7.0,2380000
|
||||
雷霆-300,火箭炮,9.8,3.0,3.4,28500,120,1850,62,0,7.8,7.4,7.2,8.0,4200000
|
||||
山猫-95,火箭炮,6.2,2.3,2.7,11800,18,360,78,0,5.4,6.0,5.7,5.0,740000
|
||||
山猫-120,火箭炮,6.7,2.4,2.8,13000,30,520,75,0,5.9,6.2,6.0,5.6,1050000
|
||||
山猫-200,火箭炮,7.9,2.7,3.1,19800,60,980,70,0,6.8,6.8,6.5,6.7,1980000
|
||||
山猫-300,火箭炮,9.3,2.9,3.3,26000,105,1600,66,0,7.6,7.2,7.0,7.8,3560000
|
||||
弓兵-L,火箭炮,8.8,2.9,3.2,23500,85,1350,69,0,7.2,7.0,6.9,7.3,2860000
|
||||
弓兵-X,火箭炮,10.2,3.1,3.6,31000,150,2100,60,0,8.4,7.8,7.6,8.7,5400000
|
||||
长矛-1,火箭炮,7.1,2.5,2.9,14200,28,560,73,0,6.1,6.4,6.1,5.8,1180000
|
||||
长矛-2,火箭炮,8.1,2.7,3.1,20500,75,1120,68,0,7.1,6.9,6.7,7.1,2420000
|
||||
长矛-3,火箭炮,9.6,3.0,3.5,29200,130,1900,63,0,8.1,7.5,7.4,8.3,4650000
|
||||
擎天-M,火箭炮,10.8,3.2,3.8,34800,180,2450,58,0,8.9,8.0,7.9,9.2,6900000
|
||||
|
13
demo_standalone/README.md
Normal file
13
demo_standalone/README.md
Normal file
@ -0,0 +1,13 @@
|
||||
# 机器学习算法演示
|
||||
|
||||
## 运行方式
|
||||
|
||||
1. 解压 zip 文件。
|
||||
2. 双击 `start_demo.bat`。
|
||||
3. 浏览器会自动打开 `http://127.0.0.1:5001/algorithm-demo`。
|
||||
|
||||
## 说明
|
||||
|
||||
- 演示使用 `data/demo_equipment_costs.csv`,不需要 MySQL。
|
||||
- 首次运行会创建 `.venv` 并安装最小 Python 依赖。
|
||||
- 需要本机已安装 Python 3.9 至 3.11。
|
||||
5
demo_standalone/requirements.txt
Normal file
5
demo_standalone/requirements.txt
Normal file
@ -0,0 +1,5 @@
|
||||
flask>=3.1.0
|
||||
flask-cors>=5.0.0
|
||||
numpy>=1.26.0,<2.0.0
|
||||
pandas>=2.2.0
|
||||
scikit-learn>=1.5.2
|
||||
48
demo_standalone/server.py
Normal file
48
demo_standalone/server.py
Normal file
@ -0,0 +1,48 @@
|
||||
from pathlib import Path
|
||||
|
||||
from flask import Flask, jsonify, request, send_from_directory
|
||||
from flask_cors import CORS
|
||||
|
||||
from demo_service import DemoModelService
|
||||
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
STATIC_DIR = BASE_DIR / "frontend"
|
||||
DATASET_PATH = BASE_DIR / "data" / "demo_equipment_costs.csv"
|
||||
|
||||
|
||||
def create_app():
|
||||
app = Flask(__name__, static_folder=None)
|
||||
CORS(app)
|
||||
|
||||
@app.get("/api/demo/algorithms")
|
||||
def demo_algorithms():
|
||||
service = DemoModelService(DATASET_PATH)
|
||||
return jsonify({"algorithms": service.get_algorithms()})
|
||||
|
||||
@app.get("/api/demo/dataset")
|
||||
def demo_dataset():
|
||||
service = DemoModelService(DATASET_PATH)
|
||||
return jsonify(service.get_dataset_summary())
|
||||
|
||||
@app.post("/api/demo/run")
|
||||
def demo_run():
|
||||
payload = request.get_json(silent=True) or {}
|
||||
service = DemoModelService(DATASET_PATH)
|
||||
return jsonify(service.run_demo(payload.get("algorithms")))
|
||||
|
||||
@app.get("/")
|
||||
@app.get("/<path:path>")
|
||||
def frontend(path=""):
|
||||
file_path = STATIC_DIR / path
|
||||
if path and file_path.exists() and file_path.is_file():
|
||||
return send_from_directory(STATIC_DIR, path)
|
||||
return send_from_directory(STATIC_DIR, "index.html")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = create_app()
|
||||
print("算法演示服务已启动:http://127.0.0.1:5001/algorithm-demo")
|
||||
app.run(host="127.0.0.1", port=5001, debug=False)
|
||||
32
demo_standalone/start_demo.bat
Normal file
32
demo_standalone/start_demo.bat
Normal file
@ -0,0 +1,32 @@
|
||||
@echo off
|
||||
setlocal
|
||||
cd /d "%~dp0"
|
||||
|
||||
where python >nul 2>nul
|
||||
if errorlevel 1 (
|
||||
echo 未找到 Python。请先安装 Python 3.9 至 3.11,然后重新运行本脚本。
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
if not exist ".venv\Scripts\python.exe" (
|
||||
echo 正在创建演示环境...
|
||||
python -m venv .venv
|
||||
if errorlevel 1 (
|
||||
echo 创建环境失败。
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
)
|
||||
|
||||
echo 正在安装或检查依赖...
|
||||
".venv\Scripts\python.exe" -m pip install -r requirements.txt
|
||||
if errorlevel 1 (
|
||||
echo 依赖安装失败,请检查网络或 Python 环境。
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
start "" http://127.0.0.1:5001/algorithm-demo
|
||||
".venv\Scripts\python.exe" server.py
|
||||
pause
|
||||
57
docs/superpowers/plans/2026-04-25-ml-algorithm-demo.md
Normal file
57
docs/superpowers/plans/2026-04-25-ml-algorithm-demo.md
Normal file
@ -0,0 +1,57 @@
|
||||
# ML Algorithm Demo Implementation Plan
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** Add a modern demo page that compares common machine learning algorithms using a local data file instead of MySQL.
|
||||
|
||||
**Architecture:** Add an isolated backend demo service that reads `data/demo_equipment_costs.csv`, trains selected regressors in memory, and returns metrics, prediction points, feature importance, and a sample prediction. Add a Vue route that calls the demo API and renders algorithm switching, charts, metrics, and data preview. Existing database-backed pages remain unchanged.
|
||||
|
||||
**Tech Stack:** Flask, pandas, scikit-learn, optional xgboost/lightgbm, Vue 3, Element Plus, ECharts.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Backend Demo Service
|
||||
|
||||
**Files:**
|
||||
- Create: `tests/test_demo_service.py`
|
||||
- Create: `src/demo_service.py`
|
||||
- Create: `data/demo_equipment_costs.csv`
|
||||
|
||||
- [ ] Write failing tests for data loading, algorithm availability, and training payload shape.
|
||||
- [ ] Run `python -m pytest tests/test_demo_service.py -q` and verify it fails because `src.demo_service` is missing.
|
||||
- [ ] Implement `DemoModelService` with local CSV loading, selected algorithm training, metric calculation, top feature importance, and fallback algorithms when optional libraries are unavailable.
|
||||
- [ ] Run `python -m pytest tests/test_demo_service.py -q` and verify it passes.
|
||||
|
||||
### Task 2: Demo API
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/routes.py`
|
||||
- Test: `tests/test_demo_routes.py`
|
||||
|
||||
- [ ] Write Flask route tests for `GET /api/demo/algorithms`, `GET /api/demo/dataset`, and `POST /api/demo/run`.
|
||||
- [ ] Run `python -m pytest tests/test_demo_routes.py -q` and verify missing routes fail.
|
||||
- [ ] Add demo routes that call `DemoModelService` and do not access MySQL.
|
||||
- [ ] Run the route tests and demo service tests.
|
||||
|
||||
### Task 3: Vue Demo Page
|
||||
|
||||
**Files:**
|
||||
- Create: `frontend/src/views/AlgorithmDemoPage.vue`
|
||||
- Modify: `frontend/src/router/index.js`
|
||||
- Modify: `frontend/src/App.vue`
|
||||
- Modify: `frontend/src/api/index.js`
|
||||
- Modify: `frontend/src/views/HomePage.vue`
|
||||
|
||||
- [ ] Add API helpers for demo algorithms, dataset, and run.
|
||||
- [ ] Add `/algorithm-demo` route and navigation label `算法演示`.
|
||||
- [ ] Build a modern dashboard-style page with algorithm toggles, metric cards, comparison chart, predicted-vs-actual chart, feature importance chart, sample prediction panel, and data preview table.
|
||||
- [ ] Add a home page entry that links to the demo.
|
||||
|
||||
### Task 4: Verification
|
||||
|
||||
**Files:**
|
||||
- No new files.
|
||||
|
||||
- [ ] Run `python -m pytest tests/test_demo_service.py tests/test_demo_routes.py -q`.
|
||||
- [ ] Run `npm run build` in `frontend`.
|
||||
- [ ] Start the app if feasible and confirm the new route is available.
|
||||
@ -9,6 +9,7 @@
|
||||
<el-menu-item index="/">首页</el-menu-item>
|
||||
<el-menu-item index="/predict">成本预测</el-menu-item>
|
||||
<el-menu-item index="/analysis">特征分析</el-menu-item>
|
||||
<el-menu-item index="/algorithm-demo">算法演示</el-menu-item>
|
||||
<el-menu-item index="/training">模型训练</el-menu-item>
|
||||
<el-menu-item index="/models">模型管理</el-menu-item>
|
||||
<el-menu-item index="/datasets">数据集管理</el-menu-item>
|
||||
|
||||
@ -40,4 +40,16 @@ export const updateEquipment = (id, data) => {
|
||||
|
||||
export const deleteEquipment = (id) => {
|
||||
return api.delete(`/data/${id}`)
|
||||
}
|
||||
}
|
||||
|
||||
export const getDemoAlgorithms = () => {
|
||||
return api.get('/demo/algorithms')
|
||||
}
|
||||
|
||||
export const getDemoDataset = () => {
|
||||
return api.get('/demo/dataset')
|
||||
}
|
||||
|
||||
export const runAlgorithmDemo = (data) => {
|
||||
return api.post('/demo/run', data)
|
||||
}
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
export const API_BASE_URL = 'http://localhost:5001/api';
|
||||
const isLocalDevServer = window.location.port === '8080'
|
||||
|
||||
export const API_BASE_URL = isLocalDevServer
|
||||
? 'http://localhost:5001/api'
|
||||
: `${window.location.origin}/api`;
|
||||
|
||||
export const DB_CONFIG = {
|
||||
host: 'localhost',
|
||||
user: 'root',
|
||||
password: '123456',
|
||||
database: 'equipment_cost_db'
|
||||
};
|
||||
};
|
||||
|
||||
@ -5,6 +5,7 @@ import DatasetPage from '@/views/DatasetPage.vue'
|
||||
import PredictPage from '@/views/PredictPage.vue'
|
||||
import AnalysisPage from '@/views/AnalysisPage.vue'
|
||||
import TrainingPage from '@/views/TrainingPage.vue'
|
||||
import AlgorithmDemoPage from '@/views/AlgorithmDemoPage.vue'
|
||||
|
||||
const routes = [
|
||||
{
|
||||
@ -37,6 +38,11 @@ const routes = [
|
||||
name: 'Training',
|
||||
component: TrainingPage
|
||||
},
|
||||
{
|
||||
path: '/algorithm-demo',
|
||||
name: 'AlgorithmDemo',
|
||||
component: AlgorithmDemoPage
|
||||
},
|
||||
{
|
||||
path: '/models',
|
||||
name: 'Models',
|
||||
@ -49,4 +55,4 @@ const router = createRouter({
|
||||
routes
|
||||
})
|
||||
|
||||
export default router
|
||||
export default router
|
||||
|
||||
644
frontend/src/views/AlgorithmDemoPage.vue
Normal file
644
frontend/src/views/AlgorithmDemoPage.vue
Normal file
@ -0,0 +1,644 @@
|
||||
<template>
|
||||
<div class="algorithm-demo-page">
|
||||
<section class="demo-hero">
|
||||
<div>
|
||||
<p class="eyebrow">本地文件算法演示</p>
|
||||
<h1>机器学习算法演示</h1>
|
||||
<p class="hero-copy">
|
||||
使用本地数据文件快速训练和比较常用回归算法,适合客户演示部署。
|
||||
</p>
|
||||
</div>
|
||||
<div class="hero-actions">
|
||||
<el-button type="primary" :loading="loading" @click="runDemo">
|
||||
<el-icon><VideoPlay /></el-icon>
|
||||
运行演示
|
||||
</el-button>
|
||||
<el-tag effect="plain" type="success">无需 MySQL</el-tag>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section class="control-band">
|
||||
<div class="panel algorithm-panel">
|
||||
<div class="panel-header">
|
||||
<div>
|
||||
<p class="eyebrow">算法选择</p>
|
||||
<h2>选择算法</h2>
|
||||
</div>
|
||||
<el-button text type="primary" @click="selectRecommended">推荐组合</el-button>
|
||||
</div>
|
||||
<el-checkbox-group v-model="selectedAlgorithms" class="algorithm-grid">
|
||||
<el-checkbox
|
||||
v-for="item in algorithms"
|
||||
:key="item.key"
|
||||
:value="item.key"
|
||||
border
|
||||
>
|
||||
<span class="algorithm-name">{{ item.name }}</span>
|
||||
<small>{{ item.english_name }} · {{ item.family }}</small>
|
||||
</el-checkbox>
|
||||
</el-checkbox-group>
|
||||
<el-alert
|
||||
v-if="warnings.length"
|
||||
class="warning-strip"
|
||||
type="warning"
|
||||
:closable="false"
|
||||
show-icon
|
||||
>
|
||||
<template #title>{{ warnings.join(' ') }}</template>
|
||||
</el-alert>
|
||||
</div>
|
||||
|
||||
<div class="panel dataset-panel">
|
||||
<div class="panel-header">
|
||||
<div>
|
||||
<p class="eyebrow">数据来源</p>
|
||||
<h2>本地演示数据</h2>
|
||||
</div>
|
||||
<el-tag>{{ dataset.row_count || 0 }} 条</el-tag>
|
||||
</div>
|
||||
<div class="dataset-stats">
|
||||
<div>
|
||||
<strong>{{ dataset.features?.length || 0 }}</strong>
|
||||
<span>特征数</span>
|
||||
</div>
|
||||
<div>
|
||||
<strong>{{ dataset.equipment_types?.length || 0 }}</strong>
|
||||
<span>装备类型</span>
|
||||
</div>
|
||||
<div>
|
||||
<strong>{{ dataset.target_label || '-' }}</strong>
|
||||
<span>预测目标</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section v-if="result" class="metrics-grid">
|
||||
<article
|
||||
v-for="row in metricRows"
|
||||
:key="row.key"
|
||||
class="metric-card"
|
||||
:class="{ active: row.key === result.best_model }"
|
||||
@click="activeAlgorithm = row.key"
|
||||
>
|
||||
<div class="metric-title">
|
||||
<span>{{ row.name }}</span>
|
||||
<el-tag v-if="row.key === result.best_model" size="small" type="success">最佳</el-tag>
|
||||
</div>
|
||||
<strong>{{ formatScore(row.r2) }}</strong>
|
||||
<div class="metric-values">
|
||||
<span>平均绝对误差 {{ formatMoney(row.mae) }}</span>
|
||||
<span>均方根误差 {{ formatMoney(row.rmse) }}</span>
|
||||
</div>
|
||||
</article>
|
||||
</section>
|
||||
|
||||
<section v-if="result" class="visual-grid">
|
||||
<div class="panel chart-panel wide">
|
||||
<div class="panel-header">
|
||||
<div>
|
||||
<p class="eyebrow">效果对比</p>
|
||||
<h2>模型指标对比</h2>
|
||||
</div>
|
||||
</div>
|
||||
<div ref="metricsChartRef" class="chart"></div>
|
||||
</div>
|
||||
|
||||
<div class="panel chart-panel wide">
|
||||
<div class="panel-header">
|
||||
<div>
|
||||
<p class="eyebrow">预测结果</p>
|
||||
<h2>预测值与真实值</h2>
|
||||
</div>
|
||||
<el-select v-model="activeAlgorithm" size="small" class="algorithm-select">
|
||||
<el-option
|
||||
v-for="row in metricRows"
|
||||
:key="row.key"
|
||||
:label="row.name"
|
||||
:value="row.key"
|
||||
/>
|
||||
</el-select>
|
||||
</div>
|
||||
<div ref="predictionChartRef" class="chart"></div>
|
||||
</div>
|
||||
|
||||
<div class="panel chart-panel">
|
||||
<div class="panel-header">
|
||||
<div>
|
||||
<p class="eyebrow">模型解释</p>
|
||||
<h2>特征重要性</h2>
|
||||
</div>
|
||||
</div>
|
||||
<div ref="importanceChartRef" class="chart compact"></div>
|
||||
</div>
|
||||
|
||||
<div class="panel sample-panel">
|
||||
<div class="panel-header">
|
||||
<div>
|
||||
<p class="eyebrow">样例场景</p>
|
||||
<h2>样例装备预测</h2>
|
||||
</div>
|
||||
</div>
|
||||
<dl>
|
||||
<dt>装备名称</dt>
|
||||
<dd>{{ result.sample_prediction.input.name }}</dd>
|
||||
<dt>真实成本</dt>
|
||||
<dd>{{ formatMoney(result.sample_prediction.actual) }}</dd>
|
||||
<dt>当前算法预测</dt>
|
||||
<dd>{{ formatMoney(result.sample_prediction.predictions[activeAlgorithm]) }}</dd>
|
||||
</dl>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section class="panel data-preview">
|
||||
<div class="panel-header">
|
||||
<div>
|
||||
<p class="eyebrow">数据预览</p>
|
||||
<h2>数据文件预览</h2>
|
||||
</div>
|
||||
</div>
|
||||
<el-table :data="dataset.preview || []" height="320" stripe>
|
||||
<el-table-column prop="name" label="名称" min-width="130" fixed />
|
||||
<el-table-column prop="type" label="类型" min-width="150" />
|
||||
<el-table-column prop="weight_kg" label="重量(kg)" min-width="100" />
|
||||
<el-table-column prop="max_range_km" label="射程(km)" min-width="100" />
|
||||
<el-table-column prop="tech_level" label="技术水平" min-width="100" />
|
||||
<el-table-column prop="complexity_score" label="复杂度" min-width="100" />
|
||||
<el-table-column prop="actual_cost" label="实际成本" min-width="130">
|
||||
<template #default="scope">{{ formatMoney(scope.row.actual_cost) }}</template>
|
||||
</el-table-column>
|
||||
</el-table>
|
||||
</section>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { computed, nextTick, onMounted, onUnmounted, ref, watch } from 'vue'
|
||||
import { ElMessage } from 'element-plus'
|
||||
import { VideoPlay } from '@element-plus/icons-vue'
|
||||
import * as echarts from 'echarts'
|
||||
import { getDemoAlgorithms, getDemoDataset, runAlgorithmDemo } from '@/api'
|
||||
|
||||
const algorithms = ref([])
|
||||
const dataset = ref({})
|
||||
const selectedAlgorithms = ref(['linear', 'ridge', 'random_forest', 'gradient_boosting'])
|
||||
const activeAlgorithm = ref('random_forest')
|
||||
const result = ref(null)
|
||||
const loading = ref(false)
|
||||
const warnings = ref([])
|
||||
|
||||
const metricsChartRef = ref(null)
|
||||
const predictionChartRef = ref(null)
|
||||
const importanceChartRef = ref(null)
|
||||
const charts = []
|
||||
|
||||
const metricRows = computed(() => {
|
||||
if (!result.value?.metrics) return []
|
||||
return Object.entries(result.value.metrics).map(([key, value]) => ({
|
||||
key,
|
||||
...value
|
||||
}))
|
||||
})
|
||||
|
||||
const activeMetric = computed(() => {
|
||||
return metricRows.value.find((row) => row.key === activeAlgorithm.value) || metricRows.value[0]
|
||||
})
|
||||
|
||||
const selectRecommended = () => {
|
||||
selectedAlgorithms.value = ['linear', 'ridge', 'random_forest', 'gradient_boosting']
|
||||
}
|
||||
|
||||
const loadInitialData = async () => {
|
||||
try {
|
||||
const [algorithmResponse, datasetResponse] = await Promise.all([
|
||||
getDemoAlgorithms(),
|
||||
getDemoDataset()
|
||||
])
|
||||
algorithms.value = algorithmResponse.data.algorithms
|
||||
dataset.value = datasetResponse.data
|
||||
await runDemo()
|
||||
} catch (error) {
|
||||
ElMessage.error('加载演示数据失败')
|
||||
console.error(error)
|
||||
}
|
||||
}
|
||||
|
||||
const runDemo = async () => {
|
||||
if (!selectedAlgorithms.value.length) {
|
||||
ElMessage.warning('请至少选择一个算法')
|
||||
return
|
||||
}
|
||||
|
||||
loading.value = true
|
||||
try {
|
||||
const response = await runAlgorithmDemo({ algorithms: selectedAlgorithms.value })
|
||||
result.value = response.data
|
||||
dataset.value = response.data.dataset
|
||||
warnings.value = response.data.warnings || []
|
||||
activeAlgorithm.value = response.data.best_model
|
||||
await nextTick()
|
||||
renderCharts()
|
||||
} catch (error) {
|
||||
ElMessage.error(error.response?.data?.error || '运行演示失败')
|
||||
console.error(error)
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
const disposeCharts = () => {
|
||||
while (charts.length) {
|
||||
const chart = charts.pop()
|
||||
if (chart && !chart.isDisposed()) chart.dispose()
|
||||
}
|
||||
}
|
||||
|
||||
const renderCharts = () => {
|
||||
if (!result.value) return
|
||||
disposeCharts()
|
||||
renderMetricsChart()
|
||||
renderPredictionChart()
|
||||
renderImportanceChart()
|
||||
}
|
||||
|
||||
const renderMetricsChart = () => {
|
||||
if (!metricsChartRef.value) return
|
||||
const chart = echarts.init(metricsChartRef.value)
|
||||
charts.push(chart)
|
||||
chart.setOption({
|
||||
tooltip: { trigger: 'axis' },
|
||||
legend: { top: 0 },
|
||||
grid: { top: 48, left: 56, right: 24, bottom: 36 },
|
||||
xAxis: { type: 'category', data: metricRows.value.map((row) => row.name) },
|
||||
yAxis: [
|
||||
{ type: 'value', name: '决定系数', min: 0 },
|
||||
{ type: 'value', name: '误差' }
|
||||
],
|
||||
series: [
|
||||
{
|
||||
name: '决定系数',
|
||||
type: 'bar',
|
||||
data: metricRows.value.map((row) => Number(row.r2.toFixed(4))),
|
||||
itemStyle: { color: '#2f6fdd' }
|
||||
},
|
||||
{
|
||||
name: '平均绝对误差',
|
||||
type: 'line',
|
||||
yAxisIndex: 1,
|
||||
data: metricRows.value.map((row) => Math.round(row.mae)),
|
||||
itemStyle: { color: '#16a085' }
|
||||
},
|
||||
{
|
||||
name: '均方根误差',
|
||||
type: 'line',
|
||||
yAxisIndex: 1,
|
||||
data: metricRows.value.map((row) => Math.round(row.rmse)),
|
||||
itemStyle: { color: '#d98b18' }
|
||||
}
|
||||
]
|
||||
})
|
||||
}
|
||||
|
||||
const renderPredictionChart = () => {
|
||||
if (!predictionChartRef.value || !activeMetric.value) return
|
||||
const chart = echarts.init(predictionChartRef.value)
|
||||
charts.push(chart)
|
||||
const points = result.value.prediction_points
|
||||
chart.setOption({
|
||||
tooltip: { trigger: 'axis' },
|
||||
legend: { top: 0 },
|
||||
grid: { top: 48, left: 68, right: 24, bottom: 46 },
|
||||
xAxis: {
|
||||
type: 'category',
|
||||
data: points.map((point) => point.name),
|
||||
axisLabel: { rotate: 25 }
|
||||
},
|
||||
yAxis: { type: 'value', name: '成本' },
|
||||
series: [
|
||||
{
|
||||
name: '真实值',
|
||||
type: 'line',
|
||||
smooth: true,
|
||||
data: points.map((point) => point.actual),
|
||||
itemStyle: { color: '#202938' }
|
||||
},
|
||||
{
|
||||
name: activeMetric.value.name,
|
||||
type: 'bar',
|
||||
data: points.map((point) => point[activeAlgorithm.value]),
|
||||
itemStyle: { color: '#2f6fdd' }
|
||||
}
|
||||
]
|
||||
})
|
||||
}
|
||||
|
||||
const renderImportanceChart = () => {
|
||||
if (!importanceChartRef.value || !activeAlgorithm.value) return
|
||||
const chart = echarts.init(importanceChartRef.value)
|
||||
charts.push(chart)
|
||||
const rows = [...(result.value.feature_importance[activeAlgorithm.value] || [])].reverse()
|
||||
chart.setOption({
|
||||
tooltip: { trigger: 'axis' },
|
||||
grid: { top: 20, left: 108, right: 20, bottom: 24 },
|
||||
xAxis: { type: 'value' },
|
||||
yAxis: { type: 'category', data: rows.map((row) => featureName(row.feature)) },
|
||||
series: [
|
||||
{
|
||||
type: 'bar',
|
||||
data: rows.map((row) => Number(row.importance.toFixed(4))),
|
||||
itemStyle: { color: '#16a085' }
|
||||
}
|
||||
]
|
||||
})
|
||||
}
|
||||
|
||||
const featureName = (key) => {
|
||||
const names = {
|
||||
length_m: '长度',
|
||||
width_m: '宽度',
|
||||
height_m: '高度',
|
||||
weight_kg: '重量',
|
||||
max_range_km: '最大射程',
|
||||
payload_kg: '载荷',
|
||||
max_speed_kmh: '最大速度',
|
||||
endurance_min: '续航',
|
||||
tech_level: '技术水平',
|
||||
scale_level: '规模水平',
|
||||
supply_chain_level: '供应链',
|
||||
complexity_score: '复杂度'
|
||||
}
|
||||
return names[key] || key
|
||||
}
|
||||
|
||||
const formatMoney = (value) => {
|
||||
if (value === undefined || value === null) return '-'
|
||||
return Number(value).toLocaleString('zh-CN', {
|
||||
style: 'currency',
|
||||
currency: 'CNY',
|
||||
maximumFractionDigits: 0
|
||||
})
|
||||
}
|
||||
|
||||
const formatScore = (value) => {
|
||||
if (value === undefined || value === null) return '-'
|
||||
return Number(value).toFixed(3)
|
||||
}
|
||||
|
||||
watch(activeAlgorithm, async () => {
|
||||
await nextTick()
|
||||
renderCharts()
|
||||
})
|
||||
|
||||
window.addEventListener('resize', () => {
|
||||
charts.forEach((chart) => {
|
||||
if (chart && !chart.isDisposed()) chart.resize()
|
||||
})
|
||||
})
|
||||
|
||||
onMounted(loadInitialData)
|
||||
onUnmounted(disposeCharts)
|
||||
</script>
|
||||
|
||||
<style lang="scss" scoped>
|
||||
.algorithm-demo-page {
|
||||
min-height: calc(100vh - 60px);
|
||||
padding: 24px;
|
||||
color: #202938;
|
||||
background:
|
||||
linear-gradient(180deg, #eef3f8 0%, #f7f9fb 280px),
|
||||
#f7f9fb;
|
||||
}
|
||||
|
||||
.demo-hero,
|
||||
.control-band,
|
||||
.metrics-grid,
|
||||
.visual-grid,
|
||||
.data-preview {
|
||||
max-width: 1440px;
|
||||
margin: 0 auto 18px;
|
||||
}
|
||||
|
||||
.demo-hero {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 20px;
|
||||
min-height: 176px;
|
||||
|
||||
h1 {
|
||||
margin: 6px 0 10px;
|
||||
font-size: 36px;
|
||||
line-height: 1.2;
|
||||
letter-spacing: 0;
|
||||
}
|
||||
}
|
||||
|
||||
.hero-copy {
|
||||
max-width: 680px;
|
||||
margin: 0;
|
||||
color: #536273;
|
||||
font-size: 16px;
|
||||
line-height: 1.7;
|
||||
}
|
||||
|
||||
.hero-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.eyebrow {
|
||||
margin: 0;
|
||||
color: #2f6fdd;
|
||||
font-size: 12px;
|
||||
font-weight: 700;
|
||||
letter-spacing: 0;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
|
||||
.control-band,
|
||||
.visual-grid {
|
||||
display: grid;
|
||||
grid-template-columns: minmax(0, 1.35fr) minmax(320px, 0.65fr);
|
||||
gap: 16px;
|
||||
}
|
||||
|
||||
.panel,
|
||||
.metric-card {
|
||||
border: 1px solid #dfe6ef;
|
||||
border-radius: 8px;
|
||||
background: #fff;
|
||||
box-shadow: 0 10px 28px rgba(32, 41, 56, 0.06);
|
||||
}
|
||||
|
||||
.panel {
|
||||
padding: 18px;
|
||||
}
|
||||
|
||||
.panel-header,
|
||||
.metric-title {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 12px;
|
||||
|
||||
h2 {
|
||||
margin: 4px 0 0;
|
||||
font-size: 18px;
|
||||
line-height: 1.3;
|
||||
letter-spacing: 0;
|
||||
}
|
||||
}
|
||||
|
||||
.algorithm-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(190px, 1fr));
|
||||
gap: 10px;
|
||||
margin-top: 16px;
|
||||
|
||||
:deep(.el-checkbox) {
|
||||
width: 100%;
|
||||
height: 64px;
|
||||
margin: 0;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
:deep(.el-checkbox__label) {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
line-height: 1.2;
|
||||
}
|
||||
}
|
||||
|
||||
.algorithm-name {
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
.algorithm-grid small {
|
||||
color: #6b7786;
|
||||
}
|
||||
|
||||
.warning-strip {
|
||||
margin-top: 14px;
|
||||
}
|
||||
|
||||
.dataset-stats {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(3, 1fr);
|
||||
gap: 10px;
|
||||
margin-top: 18px;
|
||||
|
||||
div {
|
||||
padding: 14px;
|
||||
border-radius: 8px;
|
||||
background: #f2f6fa;
|
||||
}
|
||||
|
||||
strong {
|
||||
display: block;
|
||||
margin-bottom: 6px;
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
span {
|
||||
color: #667485;
|
||||
font-size: 13px;
|
||||
}
|
||||
}
|
||||
|
||||
.metrics-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(220px, 1fr));
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.metric-card {
|
||||
padding: 16px;
|
||||
cursor: pointer;
|
||||
transition: border-color 0.2s ease, transform 0.2s ease;
|
||||
|
||||
&.active,
|
||||
&:hover {
|
||||
border-color: #2f6fdd;
|
||||
transform: translateY(-2px);
|
||||
}
|
||||
|
||||
strong {
|
||||
display: block;
|
||||
margin: 14px 0;
|
||||
font-size: 30px;
|
||||
letter-spacing: 0;
|
||||
}
|
||||
}
|
||||
|
||||
.metric-values {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 8px;
|
||||
color: #617080;
|
||||
font-size: 13px;
|
||||
}
|
||||
|
||||
.chart-panel.wide {
|
||||
grid-column: span 2;
|
||||
}
|
||||
|
||||
.chart {
|
||||
width: 100%;
|
||||
height: 360px;
|
||||
|
||||
&.compact {
|
||||
height: 330px;
|
||||
}
|
||||
}
|
||||
|
||||
.algorithm-select {
|
||||
width: 220px;
|
||||
}
|
||||
|
||||
.sample-panel dl {
|
||||
display: grid;
|
||||
grid-template-columns: 110px minmax(0, 1fr);
|
||||
gap: 14px 10px;
|
||||
margin: 20px 0 0;
|
||||
}
|
||||
|
||||
.sample-panel dt {
|
||||
color: #667485;
|
||||
}
|
||||
|
||||
.sample-panel dd {
|
||||
margin: 0;
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
@media (max-width: 900px) {
|
||||
.algorithm-demo-page {
|
||||
padding: 16px;
|
||||
}
|
||||
|
||||
.demo-hero,
|
||||
.control-band,
|
||||
.visual-grid {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
|
||||
.demo-hero {
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
|
||||
h1 {
|
||||
font-size: 28px;
|
||||
}
|
||||
}
|
||||
|
||||
.chart-panel.wide {
|
||||
grid-column: span 1;
|
||||
}
|
||||
|
||||
.dataset-stats {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@ -26,6 +26,13 @@
|
||||
<p>训练和优化预测模型</p>
|
||||
</el-card>
|
||||
</el-col>
|
||||
<el-col :span="8">
|
||||
<el-card @click="$router.push('/algorithm-demo')">
|
||||
<el-icon><TrendCharts /></el-icon>
|
||||
<h3>算法演示</h3>
|
||||
<p>切换常用机器学习算法并对比预测效果</p>
|
||||
</el-card>
|
||||
</el-col>
|
||||
<el-col :span="8">
|
||||
<el-card @click="$router.push('/models')">
|
||||
<el-icon><Management /></el-icon>
|
||||
@ -53,7 +60,7 @@
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { Money, DataAnalysis, Monitor, Management, Collection } from '@element-plus/icons-vue'
|
||||
import { Money, DataAnalysis, Monitor, Management, Collection, TrendCharts } from '@element-plus/icons-vue'
|
||||
</script>
|
||||
|
||||
<style lang="scss" scoped>
|
||||
@ -98,4 +105,4 @@ import { Money, DataAnalysis, Monitor, Management, Collection } from '@element-p
|
||||
}
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</style>
|
||||
|
||||
11
html5_cost_prediction/README.txt
Normal file
11
html5_cost_prediction/README.txt
Normal file
@ -0,0 +1,11 @@
|
||||
智能成本预测系统 - HTML5离线版
|
||||
|
||||
运行方式:
|
||||
1. 解压 zip 文件。
|
||||
2. 双击 index.html。
|
||||
|
||||
说明:
|
||||
- 不需要 Python。
|
||||
- 不需要数据库。
|
||||
- 不需要联网。
|
||||
- 页面内置样例数据和模型效果,用于客户现场展示不同模型的预测差异。
|
||||
1324
html5_cost_prediction/index.html
Normal file
1324
html5_cost_prediction/index.html
Normal file
File diff suppressed because one or more lines are too long
57
scripts/build_demo_zip.ps1
Normal file
57
scripts/build_demo_zip.ps1
Normal file
@ -0,0 +1,57 @@
|
||||
param(
|
||||
[string]$OutputPath = "release\algorithm-demo-standalone.zip"
|
||||
)
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
$repoRoot = Resolve-Path (Join-Path $PSScriptRoot "..")
|
||||
$releaseRoot = Join-Path $repoRoot "release"
|
||||
$stageDir = Join-Path $releaseRoot "algorithm-demo-standalone"
|
||||
$zipPath = Join-Path $repoRoot $OutputPath
|
||||
|
||||
function Assert-InRepo([string]$PathToCheck) {
|
||||
$resolved = [System.IO.Path]::GetFullPath($PathToCheck)
|
||||
$root = [System.IO.Path]::GetFullPath($repoRoot)
|
||||
if (-not $resolved.StartsWith($root, [System.StringComparison]::OrdinalIgnoreCase)) {
|
||||
throw "Refusing to operate outside repository: $resolved"
|
||||
}
|
||||
}
|
||||
|
||||
Assert-InRepo $stageDir
|
||||
Assert-InRepo $zipPath
|
||||
|
||||
Push-Location (Join-Path $repoRoot "frontend")
|
||||
try {
|
||||
npm run build
|
||||
}
|
||||
finally {
|
||||
Pop-Location
|
||||
}
|
||||
|
||||
New-Item -ItemType Directory -Force -Path $releaseRoot | Out-Null
|
||||
|
||||
if (Test-Path $stageDir) {
|
||||
Remove-Item -LiteralPath $stageDir -Recurse -Force
|
||||
}
|
||||
if (Test-Path $zipPath) {
|
||||
Remove-Item -LiteralPath $zipPath -Force
|
||||
}
|
||||
|
||||
New-Item -ItemType Directory -Force -Path $stageDir | Out-Null
|
||||
New-Item -ItemType Directory -Force -Path (Join-Path $stageDir "data") | Out-Null
|
||||
|
||||
Copy-Item -Recurse -Path (Join-Path $repoRoot "frontend\dist") -Destination (Join-Path $stageDir "frontend")
|
||||
Copy-Item -Path (Join-Path $repoRoot "src\demo_service.py") -Destination (Join-Path $stageDir "demo_service.py")
|
||||
Copy-Item -Path (Join-Path $repoRoot "data\demo_equipment_costs.csv") -Destination (Join-Path $stageDir "data\demo_equipment_costs.csv")
|
||||
Copy-Item -Path (Join-Path $repoRoot "demo_standalone\server.py") -Destination (Join-Path $stageDir "server.py")
|
||||
Copy-Item -Path (Join-Path $repoRoot "demo_standalone\requirements.txt") -Destination (Join-Path $stageDir "requirements.txt")
|
||||
Copy-Item -Path (Join-Path $repoRoot "demo_standalone\start_demo.bat") -Destination (Join-Path $stageDir "start_demo.bat")
|
||||
Copy-Item -Path (Join-Path $repoRoot "demo_standalone\README.md") -Destination (Join-Path $stageDir "README.md")
|
||||
|
||||
Get-ChildItem -Path $stageDir -Recurse -Include "*.map", "__pycache__" | ForEach-Object {
|
||||
Remove-Item -LiteralPath $_.FullName -Recurse -Force
|
||||
}
|
||||
|
||||
Compress-Archive -Path (Join-Path $stageDir "*") -DestinationPath $zipPath -Force
|
||||
|
||||
Write-Host "Demo zip created: $zipPath"
|
||||
36
scripts/build_html5_zip.ps1
Normal file
36
scripts/build_html5_zip.ps1
Normal file
@ -0,0 +1,36 @@
|
||||
param(
|
||||
[string]$OutputPath = "release\intelligent-cost-prediction-html5.zip"
|
||||
)
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
$repoRoot = Resolve-Path (Join-Path $PSScriptRoot "..")
|
||||
$sourceDir = Join-Path $repoRoot "html5_cost_prediction"
|
||||
$releaseRoot = Join-Path $repoRoot "release"
|
||||
$stageDir = Join-Path $releaseRoot "intelligent-cost-prediction-html5"
|
||||
$zipPath = Join-Path $repoRoot $OutputPath
|
||||
|
||||
function Assert-InRepo([string]$PathToCheck) {
|
||||
$resolved = [System.IO.Path]::GetFullPath($PathToCheck)
|
||||
$root = [System.IO.Path]::GetFullPath($repoRoot)
|
||||
if (-not $resolved.StartsWith($root, [System.StringComparison]::OrdinalIgnoreCase)) {
|
||||
throw "Refusing to operate outside repository: $resolved"
|
||||
}
|
||||
}
|
||||
|
||||
Assert-InRepo $stageDir
|
||||
Assert-InRepo $zipPath
|
||||
|
||||
New-Item -ItemType Directory -Force -Path $releaseRoot | Out-Null
|
||||
|
||||
if (Test-Path $stageDir) {
|
||||
Remove-Item -LiteralPath $stageDir -Recurse -Force
|
||||
}
|
||||
if (Test-Path $zipPath) {
|
||||
Remove-Item -LiteralPath $zipPath -Force
|
||||
}
|
||||
|
||||
Copy-Item -Recurse -Path $sourceDir -Destination $stageDir
|
||||
Compress-Archive -Path (Join-Path $stageDir "*") -DestinationPath $zipPath -Force
|
||||
|
||||
Write-Host "HTML5 zip created: $zipPath"
|
||||
290
src/demo_service.py
Normal file
290
src/demo_service.py
Normal file
@ -0,0 +1,290 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
|
||||
from sklearn.linear_model import LinearRegression, Ridge
|
||||
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.neighbors import KNeighborsRegressor
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.svm import SVR
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AlgorithmDefinition:
|
||||
key: str
|
||||
name: str
|
||||
english_name: str
|
||||
family: str
|
||||
description: str
|
||||
estimator: Any
|
||||
|
||||
|
||||
class DemoModelService:
|
||||
target_column = "actual_cost"
|
||||
ignored_columns = {"name", "type", target_column}
|
||||
|
||||
def __init__(self, dataset_path: Path | str | None = None):
|
||||
root = Path(__file__).resolve().parent.parent
|
||||
self.dataset_path = Path(dataset_path) if dataset_path else root / "data" / "demo_equipment_costs.csv"
|
||||
|
||||
def get_algorithms(self) -> list[dict[str, str]]:
|
||||
algorithms, _ = self._available_algorithms()
|
||||
return [
|
||||
{
|
||||
"key": item.key,
|
||||
"name": item.name,
|
||||
"english_name": item.english_name,
|
||||
"family": item.family,
|
||||
"description": item.description,
|
||||
}
|
||||
for item in algorithms.values()
|
||||
]
|
||||
|
||||
def get_dataset_summary(self) -> dict[str, Any]:
|
||||
frame = self._load_dataset()
|
||||
feature_columns = self._feature_columns(frame)
|
||||
return {
|
||||
"source": "local-file",
|
||||
"path": str(self.dataset_path),
|
||||
"row_count": int(len(frame)),
|
||||
"columns": list(frame.columns),
|
||||
"features": feature_columns,
|
||||
"target": self.target_column,
|
||||
"target_label": "实际成本",
|
||||
"equipment_types": sorted(frame["type"].dropna().unique().tolist()),
|
||||
"preview": frame.head(8).to_dict(orient="records"),
|
||||
}
|
||||
|
||||
def run_demo(self, selected_algorithms: list[str] | None = None) -> dict[str, Any]:
|
||||
frame = self._load_dataset()
|
||||
feature_columns = self._feature_columns(frame)
|
||||
algorithms, availability_warnings = self._available_algorithms()
|
||||
|
||||
requested = selected_algorithms or list(algorithms.keys())
|
||||
warnings = list(availability_warnings)
|
||||
selected = []
|
||||
for key in requested:
|
||||
if key in algorithms:
|
||||
selected.append(key)
|
||||
else:
|
||||
warnings.append(f"算法 '{key}' 不可用,已自动跳过。")
|
||||
|
||||
if not selected:
|
||||
selected = ["linear"]
|
||||
warnings.append("所选算法均不可用,已自动使用线性回归。")
|
||||
|
||||
X = frame[feature_columns]
|
||||
y = frame[self.target_column]
|
||||
train_x, test_x, train_y, test_y = train_test_split(
|
||||
X,
|
||||
y,
|
||||
test_size=0.3,
|
||||
random_state=42,
|
||||
)
|
||||
|
||||
metrics: dict[str, dict[str, float | str]] = {}
|
||||
predictions: dict[str, list[float]] = {}
|
||||
feature_importance: dict[str, list[dict[str, float | str]]] = {}
|
||||
|
||||
for key in selected:
|
||||
definition = algorithms[key]
|
||||
model = definition.estimator
|
||||
model.fit(train_x, train_y)
|
||||
predicted = model.predict(test_x)
|
||||
predictions[key] = [float(value) for value in predicted]
|
||||
metrics[key] = {
|
||||
"name": definition.name,
|
||||
"r2": float(r2_score(test_y, predicted)),
|
||||
"mae": float(mean_absolute_error(test_y, predicted)),
|
||||
"rmse": float(np.sqrt(mean_squared_error(test_y, predicted))),
|
||||
}
|
||||
feature_importance[key] = self._feature_importance(model, feature_columns)
|
||||
|
||||
best_model = min(metrics, key=lambda key: float(metrics[key]["rmse"]))
|
||||
ordered_test = test_x.copy()
|
||||
ordered_test["actual"] = test_y
|
||||
ordered_test["name"] = frame.loc[test_x.index, "name"]
|
||||
|
||||
prediction_points = []
|
||||
for position, (index, row) in enumerate(ordered_test.sort_values("actual").iterrows()):
|
||||
point = {
|
||||
"name": row["name"],
|
||||
"actual": float(row["actual"]),
|
||||
}
|
||||
for key in selected:
|
||||
original_position = list(test_x.index).index(index)
|
||||
point[key] = predictions[key][original_position]
|
||||
prediction_points.append(point)
|
||||
|
||||
sample = frame.sort_values(self.target_column).iloc[len(frame) // 2]
|
||||
sample_x = pd.DataFrame([sample[feature_columns].to_dict()])
|
||||
sample_predictions = {
|
||||
key: float(algorithms[key].estimator.fit(X, y).predict(sample_x)[0])
|
||||
for key in selected
|
||||
}
|
||||
|
||||
return {
|
||||
"source": "local-file",
|
||||
"dataset": self.get_dataset_summary(),
|
||||
"algorithms": self.get_algorithms(),
|
||||
"selected_algorithms": selected,
|
||||
"best_model": best_model,
|
||||
"metrics": metrics,
|
||||
"feature_importance": feature_importance,
|
||||
"prediction_points": prediction_points,
|
||||
"sample_prediction": {
|
||||
"input": sample.drop(labels=[self.target_column]).to_dict(),
|
||||
"actual": float(sample[self.target_column]),
|
||||
"predictions": sample_predictions,
|
||||
},
|
||||
"warnings": warnings,
|
||||
}
|
||||
|
||||
def _load_dataset(self) -> pd.DataFrame:
|
||||
if not self.dataset_path.exists():
|
||||
raise FileNotFoundError(f"Demo dataset not found: {self.dataset_path}")
|
||||
|
||||
frame = pd.read_csv(self.dataset_path)
|
||||
if self.target_column not in frame.columns:
|
||||
raise ValueError(f"Demo dataset must include '{self.target_column}'.")
|
||||
return frame
|
||||
|
||||
def _feature_columns(self, frame: pd.DataFrame) -> list[str]:
|
||||
columns = [
|
||||
column
|
||||
for column in frame.columns
|
||||
if column not in self.ignored_columns and pd.api.types.is_numeric_dtype(frame[column])
|
||||
]
|
||||
if not columns:
|
||||
raise ValueError("Demo dataset has no numeric feature columns.")
|
||||
return columns
|
||||
|
||||
def _available_algorithms(self) -> tuple[dict[str, AlgorithmDefinition], list[str]]:
|
||||
algorithms = {
|
||||
"linear": AlgorithmDefinition(
|
||||
"linear",
|
||||
"线性回归",
|
||||
"Linear Regression",
|
||||
"线性模型",
|
||||
"快速建立基准模型,用于展示参数与成本之间的线性关系。",
|
||||
Pipeline([("scaler", StandardScaler()), ("model", LinearRegression())]),
|
||||
),
|
||||
"ridge": AlgorithmDefinition(
|
||||
"ridge",
|
||||
"岭回归",
|
||||
"Ridge Regression",
|
||||
"线性模型",
|
||||
"带正则化的线性模型,适合特征存在相关性的场景。",
|
||||
Pipeline([("scaler", StandardScaler()), ("model", Ridge(alpha=1.0))]),
|
||||
),
|
||||
"random_forest": AlgorithmDefinition(
|
||||
"random_forest",
|
||||
"随机森林",
|
||||
"Random Forest",
|
||||
"树模型集成",
|
||||
"通过多棵决策树集成预测,能够捕捉非线性特征影响。",
|
||||
RandomForestRegressor(n_estimators=160, max_depth=6, random_state=42),
|
||||
),
|
||||
"gradient_boosting": AlgorithmDefinition(
|
||||
"gradient_boosting",
|
||||
"梯度提升树",
|
||||
"Gradient Boosting",
|
||||
"树模型集成",
|
||||
"逐步修正误差的提升模型,常用于表格数据回归任务。",
|
||||
GradientBoostingRegressor(n_estimators=120, learning_rate=0.06, max_depth=3, random_state=42),
|
||||
),
|
||||
"svr": AlgorithmDefinition(
|
||||
"svr",
|
||||
"支持向量回归",
|
||||
"Support Vector Regression",
|
||||
"核方法",
|
||||
"使用核函数拟合平滑回归关系,适合展示不同算法偏好。",
|
||||
Pipeline([("scaler", StandardScaler()), ("model", SVR(C=500000, epsilon=50000))]),
|
||||
),
|
||||
"knn": AlgorithmDefinition(
|
||||
"knn",
|
||||
"近邻回归",
|
||||
"KNN Regression",
|
||||
"实例学习",
|
||||
"基于相似样本进行预测,便于解释局部相似性。",
|
||||
Pipeline([("scaler", StandardScaler()), ("model", KNeighborsRegressor(n_neighbors=4))]),
|
||||
),
|
||||
}
|
||||
warnings = []
|
||||
|
||||
try:
|
||||
from xgboost import XGBRegressor
|
||||
|
||||
algorithms["xgboost"] = AlgorithmDefinition(
|
||||
"xgboost",
|
||||
"XGBoost",
|
||||
"XGBoost",
|
||||
"提升模型",
|
||||
"面向表格数据的高性能梯度提升实现。",
|
||||
XGBRegressor(
|
||||
n_estimators=120,
|
||||
max_depth=3,
|
||||
learning_rate=0.05,
|
||||
subsample=0.9,
|
||||
colsample_bytree=0.9,
|
||||
random_state=42,
|
||||
objective="reg:squarederror",
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
warnings.append("当前环境未安装 XGBoost,页面已自动隐藏该算法。")
|
||||
|
||||
try:
|
||||
from lightgbm import LGBMRegressor
|
||||
|
||||
algorithms["lightgbm"] = AlgorithmDefinition(
|
||||
"lightgbm",
|
||||
"LightGBM",
|
||||
"LightGBM",
|
||||
"提升模型",
|
||||
"基于直方图优化的快速梯度提升模型。",
|
||||
LGBMRegressor(
|
||||
n_estimators=120,
|
||||
learning_rate=0.05,
|
||||
max_depth=4,
|
||||
random_state=42,
|
||||
verbose=-1,
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
warnings.append("当前环境未安装 LightGBM,页面已自动隐藏该算法。")
|
||||
|
||||
return algorithms, warnings
|
||||
|
||||
def _feature_importance(self, model: Any, feature_columns: list[str]) -> list[dict[str, float | str]]:
|
||||
estimator = model
|
||||
if isinstance(model, Pipeline):
|
||||
estimator = model.named_steps["model"]
|
||||
|
||||
if hasattr(estimator, "feature_importances_"):
|
||||
values = estimator.feature_importances_
|
||||
elif hasattr(estimator, "coef_"):
|
||||
values = np.abs(np.ravel(estimator.coef_))
|
||||
else:
|
||||
values = np.zeros(len(feature_columns))
|
||||
|
||||
total = float(np.sum(values))
|
||||
if total > 0:
|
||||
values = values / total
|
||||
|
||||
ranked = sorted(
|
||||
[
|
||||
{"feature": feature, "importance": float(value)}
|
||||
for feature, value in zip(feature_columns, values)
|
||||
],
|
||||
key=lambda item: item["importance"],
|
||||
reverse=True,
|
||||
)
|
||||
return ranked[:8]
|
||||
@ -11,6 +11,7 @@ import json
|
||||
import os
|
||||
from .data_preparation import DataPreparation
|
||||
from .model_trainer import ModelTrainer
|
||||
from .demo_service import DemoModelService
|
||||
from .logger import setup_logger
|
||||
import torch
|
||||
|
||||
@ -58,6 +59,38 @@ def index():
|
||||
}
|
||||
})
|
||||
|
||||
@api_bp.route('/demo/algorithms', methods=['GET'])
|
||||
def demo_algorithms():
|
||||
"""Return algorithms supported by the file-based demo."""
|
||||
try:
|
||||
service = DemoModelService()
|
||||
return jsonify({'algorithms': service.get_algorithms()})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting demo algorithms: {str(e)}")
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
@api_bp.route('/demo/dataset', methods=['GET'])
|
||||
def demo_dataset():
|
||||
"""Return the local demo dataset summary without using MySQL."""
|
||||
try:
|
||||
service = DemoModelService()
|
||||
return jsonify(service.get_dataset_summary())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting demo dataset: {str(e)}")
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
@api_bp.route('/demo/run', methods=['POST'])
|
||||
def demo_run():
|
||||
"""Run selected demo algorithms against the local data file."""
|
||||
try:
|
||||
data = request.get_json(silent=True) or {}
|
||||
service = DemoModelService()
|
||||
return jsonify(service.run_demo(data.get('algorithms')))
|
||||
except Exception as e:
|
||||
logger.error(f"Error running demo algorithms: {str(e)}")
|
||||
logger.error("Detailed traceback:", exc_info=True)
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
@api_bp.route('/predict', methods=['POST'])
|
||||
def predict():
|
||||
"""使用最优机器学习模型进行预测"""
|
||||
@ -1282,4 +1315,4 @@ def analyze_manufacturers():
|
||||
except Exception as e:
|
||||
logger.error(f"Error analyzing manufacturers: {str(e)}")
|
||||
logger.error("Detailed traceback:", exc_info=True)
|
||||
return jsonify({'error': str(e)}), 500
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
40
tests/test_demo_routes.py
Normal file
40
tests/test_demo_routes.py
Normal file
@ -0,0 +1,40 @@
|
||||
from src import create_app
|
||||
|
||||
|
||||
def test_demo_algorithms_route_returns_available_models():
|
||||
app = create_app()
|
||||
client = app.test_client()
|
||||
|
||||
response = client.get("/api/demo/algorithms")
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert any(item["key"] == "random_forest" for item in payload["algorithms"])
|
||||
|
||||
|
||||
def test_demo_dataset_route_returns_local_file_summary():
|
||||
app = create_app()
|
||||
client = app.test_client()
|
||||
|
||||
response = client.get("/api/demo/dataset")
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload["source"] == "local-file"
|
||||
assert payload["row_count"] >= 20
|
||||
|
||||
|
||||
def test_demo_run_route_returns_metrics_without_mysql():
|
||||
app = create_app()
|
||||
client = app.test_client()
|
||||
|
||||
response = client.post(
|
||||
"/api/demo/run",
|
||||
json={"algorithms": ["linear", "random_forest"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.get_json()
|
||||
assert payload["source"] == "local-file"
|
||||
assert set(payload["metrics"]) == {"linear", "random_forest"}
|
||||
assert payload["prediction_points"]
|
||||
49
tests/test_demo_service.py
Normal file
49
tests/test_demo_service.py
Normal file
@ -0,0 +1,49 @@
|
||||
from pathlib import Path
|
||||
|
||||
from src.demo_service import DemoModelService
|
||||
|
||||
|
||||
def test_demo_service_loads_local_dataset():
|
||||
service = DemoModelService(Path("data/demo_equipment_costs.csv"))
|
||||
|
||||
summary = service.get_dataset_summary()
|
||||
|
||||
assert summary["row_count"] >= 20
|
||||
assert "actual_cost" in summary["columns"]
|
||||
assert summary["target"] == "actual_cost"
|
||||
assert summary["preview"][0]["name"]
|
||||
assert summary["preview"][0]["type"] in {"巡飞弹", "火箭炮"}
|
||||
|
||||
|
||||
def test_demo_service_returns_chinese_algorithm_names_with_english_notes():
|
||||
service = DemoModelService(Path("data/demo_equipment_costs.csv"))
|
||||
|
||||
algorithms = service.get_algorithms()
|
||||
|
||||
linear = next(item for item in algorithms if item["key"] == "linear")
|
||||
assert linear["name"] == "线性回归"
|
||||
assert linear["english_name"] == "Linear Regression"
|
||||
assert linear["family"] == "线性模型"
|
||||
|
||||
|
||||
def test_demo_service_runs_multiple_algorithms():
|
||||
service = DemoModelService(Path("data/demo_equipment_costs.csv"))
|
||||
|
||||
result = service.run_demo(["linear", "random_forest", "gradient_boosting"])
|
||||
|
||||
assert result["source"] == "local-file"
|
||||
assert result["best_model"] in result["metrics"]
|
||||
assert len(result["metrics"]) == 3
|
||||
assert len(result["prediction_points"]) > 0
|
||||
assert len(result["sample_prediction"]["predictions"]) == 3
|
||||
for metrics in result["metrics"].values():
|
||||
assert {"r2", "mae", "rmse"}.issubset(metrics)
|
||||
|
||||
|
||||
def test_demo_service_ignores_unavailable_algorithms():
|
||||
service = DemoModelService(Path("data/demo_equipment_costs.csv"))
|
||||
|
||||
result = service.run_demo(["linear", "does_not_exist"])
|
||||
|
||||
assert list(result["metrics"].keys()) == ["linear"]
|
||||
assert result["warnings"]
|
||||
Loading…
Reference in New Issue
Block a user