# mechine_learning/decision_tree_regression.py
#
# 44 lines
# 1.4 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_regression
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
# Decision-tree regression demo: fit a DecisionTreeRegressor on a synthetic
# one-feature dataset, plot its predictions against the held-out points,
# save the figure to ./output/, and print test-set metrics.

# 1. Build a synthetic regression dataset (100 samples, a single feature).
X, y = make_regression(n_samples=100, n_features=1, noise=0.1, random_state=42)

# 2. Hold out 20% of the samples as the test set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# 3. Create the decision-tree regression model (seeded for reproducibility).
regressor = DecisionTreeRegressor(random_state=42)

# 4. Fit the model on the training split.
regressor.fit(X_train, y_train)

# 5. Predict on the test split.
y_pred = regressor.predict(X_test)

# 6. Visualise the results.
plt.figure(figsize=(10, 6))
# Ground-truth test points.
plt.scatter(X_test, y_test, color='blue', label='true_label')
# Model predictions at the same x positions.
plt.scatter(X_test, y_pred, color='red', label='predict_label')
# Evaluate the model on a dense, evenly spaced grid spanning the test range so
# the prediction is drawn as a continuous line (no sorting of X is involved).
# ndarray.min()/max() yield true scalars; the builtin min()/max() on the 2-D
# X_test return 1-element arrays, which np.arange can only consume through
# size-1-array-to-scalar conversion (deprecated in recent NumPy releases).
X_grid = np.arange(X_test.min(), X_test.max(), 0.01).reshape(-1, 1)
y_grid = regressor.predict(X_grid)
plt.plot(X_grid, y_grid, color='green', label='regression_curve')
plt.title("decision_tree_regression")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
# savefig does not create missing parent directories, so make sure the
# output folder exists before writing the image.
os.makedirs("./output", exist_ok=True)
plt.savefig("./output/decision_tree_r.png")

# 7. Report model performance on the test split (optional).
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"均方误差MSE: {mse:.2f}")
print(f"R^2 分数: {r2:.2f}")