寻找模型的最佳权重

数据挖掘 神经网络 逻辑回归 线性回归
2022-02-26 14:59:08

我正在尝试实现一种算法来找到函数的最小值。

在转向 sigmoid 激活函数之前,我试图了解线性回归。

通常,梯度下降算法用于找到算法收敛的最小值,但线性模型还有其他一些方法。

假设我有两个向量:

x=[1,2,3,4,5,6,7,8,9,10,11,12]

y=[2.3,2.33,2.29,2.3,2.36,2.4,2.46,2.5,2.48,2.43,2.38,2.35]

在此处输入图像描述

在这些点之间,我想添加一个带有最小二乘的线性分隔符。

假设我有一些不完美的线性函数:

f(x)=0.026x+2.3

据我所知,有两种方法可以找到它:

w=(XTX)1XTy

和梯度下降算法:

wnew=woldνdydx

尽管对于线性模型,求导数是微不足道的,因此不需要第二种方法。

现在我在 Python 中使用了向量的第一个方程:

w = ((np.transpose(x)*x)**-1)*np.transpose(x)*y

不幸的是,输出无关紧要:

[ 2.3, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ]

然后我尝试在 Python 中使用第二种方法进行 500 次迭代:

for i in range(1,5000):
    x_old = x_new
    x_new = x_old - v*dydx
    print("x_new = {0} - {1}({2}) = {3}".format(x_old, v, dydx, x_new))

但是,我不确定如何知道它何时达到收敛点。

如何将这些方法正确用于线性模型?如果是这样,它们如何用于更复杂的模型,例如逻辑回归?

2个回答

在这种情况下,您的特征矩阵具有单一维度。图表中的每个点都有一个值,该值仅取决于的 1 个值。Xyx

好的,让我们看一下代码

x=[1,2,3,4,5,6,7,8,9,10,11,12]
y=[2.3,2.33,2.29,2.3,2.36,2.4,2.46,2.5,2.48,2.43,2.38,2.35]  

让我们将这些转换为矩阵。我们还将在矩阵的末尾添加一列 1。这将用于训练偏差值。X

temp = np.ones((len(x), 2))
temp[:,0] = np.asarray(x)
x = temp
y = np.asarray(y)

现在我们将计算权重为

w=(XTX)1XTy

w = np.matmul(np.matmul(np.linalg.inv(np.matmul(np.transpose(x), x)), np.transpose(x)), y)

数组([0.01174825,2.30530303])

查看我们的权重向量的维度。它只有 2 个值。与相关的一个值x,我们的第一列X矩阵和偏差,与我们添加的 1 的列相关联。这条线的方程描述为

y=0.01174825x1+2.30530303

在此处输入图像描述

我们可以看到这条线确实很好地描述了线性回归的数据。

更深的

但是,您的数据看起来更适合使用多项式。你应该试试

y=w1x12+w2x1+b

为此,在X对应的矩阵x2.

x=[1,2,3,4,5,6,7,8,9,10,11,12]
y=[2.3,2.33,2.29,2.3,2.36,2.4,2.46,2.5,2.48,2.43,2.38,2.35] 
temp = np.ones((len(x), 3))
temp[:,0] = np.power(np.asarray(x), 2)
temp[:,1] = np.asarray(x)
x = temp
y = np.asarray(y)

w = np.matmul(np.matmul(np.linalg.inv(np.matmul(np.transpose(x), x)), np.transpose(x)), y)

[![xx = range(1,15,1)
yy = \[0\]*len(xx)
for ix, i in enumerate(xx):
    yy\[ix\] = w\[0\]*i**2 + w\[1\]*i + w\[2\]][2]][2]

在此处输入图像描述

更深

并通过添加更进一步x3术语与我们得到的方式相同

在此处输入图像描述

确保不要给你的多项式添加太高的次数,否则你会过度拟合!这意味着尽管您完美地描述了您的训练数据,但它不能很好地推广到新实例。因此,这将是一个无用的模型。这就是为什么您需要拆分训练和测试数据,这样您就可以验证使用训练数据构建的模型是否可以泛化。

你也误解了-1..它不是指数而是矩阵逆的符号

.T用于转置......(有点Pythonic方便)

只是为了帮助你,搜索np.linalg.inv

对于您的第二个查询,

参考这里

再补充一点,当你看到你的损失没有改善或者正在改善时,你就停下来5th decimal place..

希望这可以帮助..