mechine_learning/knn_classification.py
2025-02-05 17:19:24 +08:00

42 lines
1.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Import the required libraries.
from pathlib import Path

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
# --- KNN classification on the Iris dataset ---------------------------------
# Loads Iris, holds out a test split, standardizes the features, fits a
# k-nearest-neighbors classifier, prints evaluation metrics and saves a
# scatter plot of the test-set predictions to OUTPUT_PATH.

N_NEIGHBORS = 3  # K: number of neighbors that vote on each prediction
TEST_SIZE = 0.2  # fraction of samples held out for evaluation (80/20 split)
RANDOM_STATE = 42  # fixed seed so the train/test split is reproducible
# Relative to the current working directory, same as the original "./output/...".
OUTPUT_PATH = Path("output") / "knn_c.png"

# 1. Load the dataset.
iris = load_iris()
X = iris.data  # feature matrix
y = iris.target  # integer class labels (indices into iris.target_names)

# 2. Split into training and test sets.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=TEST_SIZE, random_state=RANDOM_STATE
)

# 3. Standardize the features. KNN is distance-based, so features on larger
#    scales would otherwise dominate the distance. The scaler is fitted on the
#    training set only, so no test-set statistics leak into training.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 4. Build the KNN classifier.
knn = KNeighborsClassifier(n_neighbors=N_NEIGHBORS)

# 5. Train the model.
knn.fit(X_train, y_train)

# 6. Predict on the held-out set.
y_pred = knn.predict(X_test)

# 7. Evaluate the model.
print("Accuracy:", accuracy_score(y_test, y_pred))
print("Classification Report:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))

# Visualization. Iris has four features, so this is only a 2-D projection onto
# the first two *standardized* features. Points are colored by the PREDICTED
# class. Each class gets its own scatter call so the legend can name it.
# Colors are sampled from viridis exactly as `c=y_pred, cmap="viridis"` would
# map labels 0..n_classes-1.
cmap = plt.get_cmap("viridis")
n_classes = len(iris.target_names)
plt.figure()
for class_idx, class_name in enumerate(iris.target_names):
    mask = y_pred == class_idx
    plt.scatter(
        X_test[mask, 0],
        X_test[mask, 1],
        color=cmap(class_idx / max(n_classes - 1, 1)),
        marker="o",
        label=class_name,
    )
plt.xlabel(f"{iris.feature_names[0]} (standardized)")
plt.ylabel(f"{iris.feature_names[1]} (standardized)")
plt.title("KNN Classification (Predictions)")
plt.legend()
# savefig does not create missing directories; make sure the target exists.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(OUTPUT_PATH)
plt.close()  # release the figure now that it has been written to disk