mechine_learning/svm_classification.py
2025-02-05 14:55:11 +08:00

34 lines
1.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import classification_report, confusion_matrix
# 加载数据集
iris = datasets.load_iris()
X = iris.data
y = iris.target
# 将数据集拆分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 特征缩放对于SVM来说很重要
scaler = StandardScaler()
# fit_transform 计算数据集的参数,然后应用这是参数来转换数据
X_train = scaler.fit_transform(X_train)
# transform 使用前面fit学习到的参数来转换测试集参数.
X_test = scaler.transform(X_test)
# 创建SVM分类器
svm_classifier = SVC(kernel='linear') # 你可以尝试其他内核,如'rbf', 'poly'等
# 训练分类器
svm_classifier.fit(X_train, y_train)
# 对测试集进行预测
y_pred = svm_classifier.predict(X_test)
# 输出分类结果
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
print("\nClassification Report:\n", classification_report(y_test, y_pred))