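# mechine_learning/random_forest_regression.py
#
# NOTE: the lines below are repository-viewer page text that was copied in
# together with the script. They are kept as comments so the file parses.
#
# 75 lines
# 2.6 KiB
# Python
# Raw Permalink Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other
# characters. If you think that this is intentional, you can safely ignore
# this warning. Use the Escape button to reveal them.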

# 导入必要的库
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
# from sklearn.datasets import fetch_california_housing
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
# Local directory holding the California Housing data files.
data_path = "/home/admin-root/haotian/ML/dataset/CaliforniaHousing"

# Load the comma-separated numeric table; every column is parsed as float.
X = np.genfromtxt(f"{data_path}/cal_housing.data", delimiter=',', dtype=float)

# Read the feature names from the companion ".domain" file.
# NOTE(review): presumably one attribute description per line, and it may also
# list the target column -- confirm against the actual file.
with open(f"{data_path}/cal_housing.domain", "r", encoding="utf-8") as f:
    # Blank lines are skipped so they do not become empty feature names.
    feature_names = [line.strip() for line in f if line.strip()]

# Print the names so they can be checked by eye.
print("Feature Names:", feature_names)

# Assumes the target variable is the last column of the table -- TODO confirm.
y = X[:, -1]   # target (house value)
X = X[:, :-1]  # all remaining columns are the features
# # 1. 加载数据集(使用加州房价数据集)
# data = fetch_california_housing(data_home="/home/admin-root/haotian/ML/dataset/CaliforniaHousing")
# X = data.data # 特征
# y = data.target # 目标变量(房价)
# 2. Hold out 20% of the rows for testing; the remaining 80% are for training.
#    A fixed seed keeps the split reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# 3-4. Build a 100-tree random forest regressor and fit it on the training
#      split. ``fit`` returns the estimator itself, so the calls can be chained.
rf_regressor = RandomForestRegressor(
    n_estimators=100,
    random_state=42,
).fit(X_train, y_train)

# 5. Predict the target for the held-out test rows.
y_pred = rf_regressor.predict(X_test)
# 6. Evaluate the predictions against the held-out targets.
mae = mean_absolute_error(y_test, y_pred)
mse = mean_squared_error(y_test, y_pred)
rmse = np.sqrt(mse)  # square root of the MSE
r2 = r2_score(y_test, y_pred)

# Each label is separated from its value explicitly; previously the two ran
# together in the output (e.g. "MAE0.1234"), and the first label was truncated.
print(f"平均绝对误差 (MAE): {mae:.4f}")
print(f"均方误差 (MSE): {mse:.4f}")
print(f"均方根误差 (RMSE): {rmse:.4f}")
print(f"R² 评分 (R2 Score): {r2:.4f}")
# 7. Visualise predicted values against true values.
plt.figure(figsize=(8, 6))
plt.scatter(y_test, y_pred, alpha=0.6)

# Reference line y = x over the range of the true values: points on this line
# are perfect predictions. The bounds are computed once with ndarray methods.
target_min, target_max = y_test.min(), y_test.max()
plt.plot([target_min, target_max], [target_min, target_max], '--r')

plt.xlabel("true_label")
plt.ylabel("predict_label")
plt.title("random_forest_regression")

# savefig does not create missing directories, so make sure ./output exists
# before writing; otherwise the script fails only after training has finished.
output_path = Path('./output/random_forest_r.png')
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path)
# # 8. 特征重要性可视化
# feature_importances = rf_regressor.feature_importances_
# plt.figure(figsize=(10, 6))
# plt.barh(feature_names, feature_importances)
# plt.xlabel("feature importance")
# plt.ylabel("feature name")
# plt.title("random_forest_regression - feature_importance")
# plt.savefig('./output/random_forest_r_feature.png')