diff --git a/linear.py b/linear.py index 5018f74..aa658c6 100644 --- a/linear.py +++ b/linear.py @@ -1,5 +1,5 @@ ''' - 线性回归 + 线性回归, 用于回归任务, 通过线性方程预测连续值 ''' import numpy as np @@ -12,9 +12,14 @@ from sklearn.metrics import mean_squared_error, r2_score # 假设我们有一些数据点 (X, y),其中 X 是输入特征,y 是目标变量 # 设置随机种子, 不然每次运行程序结果都不一样 np.random.seed(0) + +# 生成形状为[100, 1] 值为[0, 1)均匀分布的数组 X = 2 * np.random.rand(100, 1) y = 4 + 3 * X + np.random.randn(100, 1) +print(X) +print(y) + # 将数据分为训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)