修改--线性回归示例--添加注释和保存图片

This commit is contained in:
haotian 2025-02-05 11:53:03 +08:00
parent f2af7c49f3
commit 877b061a91

View File

@ -1,5 +1,5 @@
'''
线性回归
线性回归, 用于回归任务, 通过线性方程预测连续值
'''
import numpy as np
@ -12,9 +12,14 @@ from sklearn.metrics import mean_squared_error, r2_score
# 假设我们有一些数据点 (X, y),其中 X 是输入特征y 是目标变量
# 设置随机种子, 不然每次运行程序结果都不一样
np.random.seed(0)
# 生成形状为[100, 1] 值为[0, 1)均匀分布的数组
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)
print(X)
print(y)
# 将数据分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)