diff --git a/svm_classification.py b/svm_classification.py new file mode 100644 index 0000000..c4df440 --- /dev/null +++ b/svm_classification.py @@ -0,0 +1,34 @@ +from sklearn import datasets +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.svm import SVC +from sklearn.metrics import classification_report, confusion_matrix + +# 加载数据集 +iris = datasets.load_iris() +X = iris.data +y = iris.target + +# 将数据集拆分为训练集和测试集 +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) + +# 特征缩放(对于SVM来说很重要) +scaler = StandardScaler() + +# fit_transform 计算数据集的参数,然后应用这是参数来转换数据 +X_train = scaler.fit_transform(X_train) +# transform 使用前面fit学习到的参数来转换测试集参数. +X_test = scaler.transform(X_test) + +# 创建SVM分类器 +svm_classifier = SVC(kernel='linear') # 你可以尝试其他内核,如'rbf', 'poly'等 + +# 训练分类器 +svm_classifier.fit(X_train, y_train) + +# 对测试集进行预测 +y_pred = svm_classifier.predict(X_test) + +# 输出分类结果 +print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred)) +print("\nClassification Report:\n", classification_report(y_test, y_pred)) \ No newline at end of file