# Original listing metadata (from the code viewer): 42 lines, 1.3 KiB, Python
# Import the required libraries
import os

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt

# 1. Load the dataset
iris = load_iris()
X = iris.data  # feature data
y = iris.target  # target labels

# 2. Split the dataset (80% training, 20% test); the fixed random_state keeps
#    the split reproducible between runs
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 3. Standardize the features (important: KNN is sensitive to feature scale).
#    The scaler is fitted on the training split only and then reused on the
#    test split, so no test-set statistics leak into training.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 4. Build the KNN classification model (K=3 here)
knn = KNeighborsClassifier(n_neighbors=3)

# 5. Train the model
knn.fit(X_train, y_train)

# 6. Predict on the held-out test split
y_pred = knn.predict(X_test)

# 7. Evaluate model performance
print("Accuracy:", accuracy_score(y_test, y_pred))
print("Classification Report:")
print(classification_report(y_test, y_pred))

# Visualization: a 2-D scatter plot can only show two dimensions, so only the
# first two (standardized) feature columns of the test split are plotted,
# colored by the predicted class.
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_pred, cmap='viridis', marker='o', label='Predictions')
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.title("KNN Classification (Predictions)")

# savefig does not create missing parent directories, so make sure the output
# directory exists before writing the figure.
output_path = "./output/knn_c.png"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path)