mechine_learning/knn_regression.py
2025-02-05 17:33:29 +08:00

47 lines
1.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 导入必要的库
import os

import matplotlib.pyplot as plt
from sklearn.datasets import fetch_california_housing
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor
from sklearn.preprocessing import StandardScaler
# KNN regression demo on the California housing dataset.
# Pipeline: load data -> 80/20 split -> standardize -> fit KNN (K=5) ->
# print MAE / MSE / R^2 on the test split -> save a true-vs-predicted plot.

# Filesystem locations used by this script, kept together so they are easy
# to adjust.
# NOTE(review): DATA_HOME is a machine-specific absolute path — confirm it is
# valid (and writable) on the machine where this script runs.
DATA_HOME = "/home/admin-root/haotian/ML/dataset"
OUTPUT_PATH = "./output/knn_r.png"

# 1. Load the dataset (California housing).
data = fetch_california_housing(data_home=DATA_HOME)
X = data.data  # feature data
y = data.target  # target data (house value)

# 2. Split the dataset (80% train, 20% test); fixed seed for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# 3. Standardize features (KNN is sensitive to feature scale).
# The scaler is fitted on the training split only and then applied to the
# test split, so no test-set statistics are used during fitting.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 4. Build the KNN regression model (K=5 here).
knn = KNeighborsRegressor(n_neighbors=5)

# 5. Train the model.
knn.fit(X_train, y_train)

# 6. Predict on the held-out test split.
y_pred = knn.predict(X_test)

# 7. Evaluate model performance.
mae = mean_absolute_error(y_test, y_pred)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

# Print the evaluation metrics.
print(f"Mean Absolute Error: {mae:.4f}")
print(f"Mean Squared Error: {mse:.4f}")
print(f"R^2 Score: {r2:.4f}")

# Visualize predictions against the true values (one point per test sample).
plt.scatter(y_test, y_pred, c='blue', edgecolors='k', alpha=0.7)
plt.xlabel("True Values")
plt.ylabel("Predictions")
plt.title("KNN Regression - True vs Predicted")

# savefig does not create missing parent directories, so make sure the output
# directory exists; otherwise the script would fail only after all the
# training and evaluation work has already been done.
output_dir = os.path.dirname(OUTPUT_PATH)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
plt.savefig(OUTPUT_PATH)