如何自动测试线性模型中转换自变量的最佳参数

数据挖掘 机器学习 回归 特征选择 梯度下降
2021-10-05 16:12:09

假设我有一个线性模型k变量:

y=β0+β1x1++βkxk.

现在,我想添加变量xk+1,但是,根据领域知识,yxk+1不是线性的,而是“S形”。为了捕捉这种依赖关系,一个成熟的方法是使用参数化的 arcus tangens 函数:Darctanxk+1AB+C. 要将其包含在线性模型中,我们可以放心地忽略DC参数(因为它们将有助于βk+1β0分别),所以最终我会得到模型:

y=β0+β1x1++βkxk+βk+1arctanxk+1AB.

我想做的是找到一种方法来测试不同的转换(使用不同的AB参数自动。首先,我去搜索不同的网格AB值,即将模型拟合到不同的变换xk+1并返回模型列表,按R2,RMSE或者AIC. 然而,这通常有利于非常尖锐的曲线:

在此处输入图像描述

不过,这在视力测试中没有多大意义。我知道眼睛测试可能有偏差,但对我来说,更可靠的依赖性将被以下形状捕获:

在此处输入图像描述

人们可能已经注意到,有很多观察结果xk+1是或接近于0,这实际上是那些将被近似的变量的特征arctan. 这可能会影响拟合,但这里仍然有些东西显然不起作用。

我考虑的另一种方法是使用梯度下降搜索而不是使用标准软件(如lm来自 R)来拟合模型。问题是,有了这个arctan变换,我们不再有凸目标函数。

然后我怎样才能自动测试的最佳参数arctan转型?我的想法是否正确,或者我可以以不同的方式解决这个问题吗?

4个回答

一种方法是使用为非凸问题(如贝叶斯优化)设计的算法。但是,如果您已经评估了一个精细的参数网格,那么这不太可能提供显着的改进。这是一个示例,说明如何针对此问题实施贝叶斯优化。

首先,我们需要一些数据。只是为了好玩,让我们从您发布的图像中提取数据(简而言之,因为这是题外话)。

在数学中:

img = Import[NotebookDirectory[]<>"LHrXQ.png"]
img2 = ImageResize[ImageTake[img, {40, 450}, {50, 1000}], 200]
pixels = {(#[[1]]-10)*3.8,#[[2]]*9+100}&/@PixelValuePositions[img2,Black, 0.4];

ListPlot[pixels, Frame->True, ImageSize->500]

在此处输入图像描述

现在为了使用贝叶斯优化,我们需要将目标定义为参数 A 和 B 的函数。这里我们将最大化拟合模型的 R^2 值。

在蟒蛇中:

import pandas as pd
import numpy as np
import random
from sklearn.linear_model import LinearRegression

data = pd.read_csv('mathematica_data.csv')

def objective(params):
    """Whatever you want to do for your regression."""

    # So it works with GpyOpt
    A = params[0][0]
    B = params[0][1]

    temp = data.copy()

    # Transform variable
    xt = [np.arctan((x - A)/B) for x in data['x1'].tolist()]
    temp['x1'] = xt

    # Fit a linear model
    reg = LinearRegression().fit(temp.drop('y', axis=1), temp['y'])

    # Compute scores of interest
    r2 = reg.score(temp.drop('y', axis=1), temp['y'])

    # GPyOpt will minimize so we want -f
    return - r2

现在使用 GPyOpt 优化目标。

import GPyOpt

domain = [{'name': 'A', 'type': 'continuous', 'domain': (-300.0,300.0)},
          {'name': 'B', 'type': 'continuous', 'domain': (0.1,300.0)}]


bo = GPyOpt.methods.BayesianOptimization(f=objective,
                                         domain=domain,
                                         model_type='GP',
                                         acquisition_type='EI',
                                         initial_design_numdata=10,
                                         initial_design_type='random',
                                         acquisition_jitter=0.01,
                                         num_cores=-1,
                                         de_duplication=True,
                                         exact_feval=True)

# Run optimization
bo.run_optimization(max_iter=100)

我们可以绘制优化器收敛:

bo.plot_convergence()

在此处输入图像描述

以及有关如何对参数进行采样的信息:

bo.plot_acquisition() 

在此处输入图像描述

当然,仅使用从绘图中提取的自变量,即使是最好的 R^2 值也表明 arctan((xA)/B) 之间几乎没有关系。

您可以构建优化问题。

您的特征将是参数,然后您创建一个损失函数并尝试使用梯度下降或蛮力将其最小化(取决于您的问题的搜索空间)

我不确定您是否受制于问题中提出的模型类型。然而,另一种选择是使用广义加法模型(GAM),例如回归样条或区域回归。这些方法通常可以很好地拟合非线性模式X 并且不需要提供参数化 X 这样就很容易找到好的模型。

这是一个基于模拟数据的示例(R代码):

# Generate data
x <- -50:100
y <- 0.001*x^3
plot(x,y)
df = data.frame(y,x)

# Linear regression
reg_ols=lm(y~.,data=df)
pred_ols = predict(reg_ols, newdata=df)

# GAM with regression splined (df=3)
library(gam)
reg_gam = gam(y~s(x,3), data=df)
pred_gam = predict(reg_gam, newdata=df)

# Find opt. number of splines
library(Metrics)
for (sp in seq(1:50)){
  gamx=gam(y~s(x,sp), data=df)
  print(mse(y, predict(gamx, newdata=df)))
}

# Plot prediction and actual data
require(ggplot2)
df2 = data.frame(x,y,pred_ols, pred_gam)
ggplot(df2, aes(x)) +                    
  geom_line(aes(y=y),size=1, colour="red") +  
  geom_line(aes(y=pred_ols),size=1, colour="blue") +
  geom_line(aes(y=pred_gam),size=1, colour="black", linetype = "dashed")

如您所见,该模型非常适合我的非线性函数,而无需提供参数化(见图)。有关应用介绍,请参阅ISL 第 7.7 章。

在此处输入图像描述

您可以使用非线性最小二乘法,其中一个回归量是您的反正切函数,还有两个要估计的参数。

在 R 中,例如:

library(minpack.lm)

df <- datasets::airquality

my_atan <- function(x, A, B){atan((x-A)/B)}

nlsLM(Ozone ~ a + b * Temp + c * my_atan(Temp, A, B),
      data = df,
      start = list(a = 0, b = 0, c = 0, A = 0, B = 1))