完成--svm回归示例

This commit is contained in:
haotian 2025-02-05 15:02:08 +08:00
parent f41920e20e
commit 0cd8099f2e
2 changed files with 41 additions and 0 deletions

BIN
output/svm_r.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

41
svm_regression.py Normal file
View File

@ -0,0 +1,41 @@
import numpy as np
from sklearn.svm import SVR
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
# 生成一些示例数据
# 这里我们创建一个简单的非线性关系作为示例
np.random.seed(0)
X = np.sort(5 * np.random.rand(40, 1), axis=0)
y = np.sin(X).ravel()
# 为了使问题更具挑战性,我们向目标变量添加一些噪声
y[::5] += 3 * (0.5 - np.random.rand(8))
# 将数据集拆分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 创建SVM回归模型
svr_rbf = SVR(kernel='rbf', C=100, gamma=0.1, epsilon=0.1)
# 训练模型
svr_rbf.fit(X_train, y_train)
# 对测试集进行预测
y_pred = svr_rbf.predict(X_test)
# 评估模型性能
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"Mean Squared Error: {mse:.2f}")
print(f"R^2 Score: {r2:.2f}")
# 可视化结果
plt.scatter(X, y, color='darkorange', label='data')
plt.plot(X_test, y_pred, color='navy', lw=2, label='SVR model')
plt.xlabel('data')
plt.ylabel('target')
plt.title('Support Vector Regression (SVR)')
plt.legend()
plt.savefig('./output/svm_r.png')