在 R 中可视化具有 6 个预测变量的线性模型

机器算法验证 r 回归 数据可视化
2022-03-24 12:04:20

我最近一直想知道如何可视化具有 >1 预测器的线性模型。今天我看到一篇论文,它设法在散点图中可视化了一个包含 6 个预测变量(和一个响应变量)的线性模型。

该模型与此类似:

体重 ~ 自 9 月 1 日起的天数 + 地点 + 年龄 + 年份 + PC1 + 到达日期

体重是一个连续的响应变量。自 9 月 1 日以来的天数是一个整数预测因子,站点是一个因子预测因子(2 个水平),年龄是一个因子预测因子(2 个水平),年份是一个因子预测因子(2 个水平),PC1 是一个连续预测因子,到达日期是一个整数预测器。 

散点图的 y 轴是模型预测的重量值。散点图的 x 轴是自 9 月 1 日以来的天数。八条最佳拟合线显示了三个因子预测变量每个水平的预测值。PC1 和到达日期“保持其平均值不变”。 

情节看起来像这样(使用 Photoshop 制作):

在此处输入图像描述

我希望有人可以使用安装在 R 中的数据集重新创建与此类似的图?我对如何保持两个预测变量“在它们的平均值上保持不变”特别感兴趣,以实现其他预测变量的可视化。mtcars 数据集可能是合适的。

1个回答

这是一些希望不言自明的代码:

set.seed(20987)     # for reproducability

N = 200

  # variables
days_since   = rpois(N, lambda=60)
site         = factor(sample(c("site1", "site2"), N, replace=T), c("site1", "site2"))
age          = factor(sample(c("juv", "adult"),   N, replace=T), c("juv", "adult"))
year         = factor(sample(c("2012", "2013"),   N, replace=T), c("2012", "2013"))
PC1          = rnorm(N, mean=100, sd=25)
arrival_date = sample.int(365, N, replace=T)

  # betas
B0  =  13
Bds =  74
Bs  = 114
Ba  = 160
By  = 191
Bpc =  59
Bad =  11

  # response variable
weight = B0 + Bds*days_since + Bs*(site=="site2") + Ba*(age=="adult") + 
         By*(year=="2013") + Bpc*PC1 + Bad*arrival_date + rnorm(N, mean=0, sd=10)

model = lm(weight~days_since+site+age+year+PC1+arrival_date)

  # predicted values for plot
ds    = seq(min(days_since), max(days_since))
ds1j2 = predict(model, data.frame(days_since=ds, site="site1", age="juv",   
                       year="2012", PC1=mean(PC1), arrival_date=mean(arrival_date)))
ds1j3 = predict(model, data.frame(days_since=ds, site="site1", age="juv",   
                       year="2013", PC1=mean(PC1), arrival_date=mean(arrival_date)))
ds1a2 = predict(model, data.frame(days_since=ds, site="site1", age="adult", 
                       year="2012", PC1=mean(PC1), arrival_date=mean(arrival_date)))
ds1a3 = predict(model, data.frame(days_since=ds, site="site1", age="adult", 
                       year="2013", PC1=mean(PC1), arrival_date=mean(arrival_date)))
ds2j2 = predict(model, data.frame(days_since=ds, site="site2", age="juv",   
                       year="2012", PC1=mean(PC1), arrival_date=mean(arrival_date)))
ds2j3 = predict(model, data.frame(days_since=ds, site="site2", age="juv",   
                       year="2013", PC1=mean(PC1), arrival_date=mean(arrival_date)))
ds2a2 = predict(model, data.frame(days_since=ds, site="site2", age="adult", 
                       year="2012", PC1=mean(PC1), arrival_date=mean(arrival_date)))
ds2a3 = predict(model, data.frame(days_since=ds, site="site2", age="adult", 
                       year="2013", PC1=mean(PC1), arrival_date=mean(arrival_date)))

  # plot
windows()
  plot(x=ds, y=ds1j2, ylim=c(11000, 14500), type="l", lty=1,
       ylab="predicted weight", xlab="days since 1st Sept")
                                points(range(ds), range(ds1j2), pch=5)
  lines(x=ds, y=ds1j3, lty=1);  points(range(ds), range(ds1j3), pch=18)
  lines(x=ds, y=ds1a2, lty=2);  points(range(ds), range(ds1a2), pch=5)
  lines(x=ds, y=ds1a3, lty=2);  points(range(ds), range(ds1a3), pch=18)
  lines(x=ds, y=ds2j2, lty=1);  points(range(ds), range(ds2j2), pch=1)
  lines(x=ds, y=ds2j3, lty=1);  points(range(ds), range(ds2j3), pch=16)
  lines(x=ds, y=ds2a2, lty=2);  points(range(ds), range(ds2a2), pch=1)
  lines(x=ds, y=ds2a3, lty=2);  points(range(ds), range(ds2a3), pch=16)

  legend("bottomright", lty=rep(1:2, 4), pch=c(5,18,5,18,1,16,1,16), 
         legend=c("site 1, juveniles, 2012", "site 1, juveniles, 2013", 
                  "site 1, adults,    2012", "site 1, adults,    2013", 
                  "site 2, juveniles, 2012", "site 2, juveniles, 2013", 
                  "site 2, adults,    2012", "site 2, adults,    2013")) 

您可以编写更短的代码,方法是编写将读取列表并为您完成所有这些操作的函数,而不是连续八次复制和粘贴相同的内容,但这应该更容易理解。这是情节:

在此处输入图像描述

当有交互(线不平行)时,这种情节更有趣/有用。在这种情况下,我们只有一组八行,它们相对于彼此垂直移动。