随机森林和预测

机器算法验证 随机森林 预言
2022-02-06 03:57:13

我试图了解随机森林是如何工作的。我对树木的构建方式有所了解,但无法理解随机森林如何对袋外样本进行预测。谁能给我一个简单的解释,好吗?:)

1个回答

森林中的每棵树都是从训练数据中观察的引导样本构建的。bootstrap 样本中的那些观察构建了树,而 bootstrap 样本中没有的那些形成了袋外(或 OOB)样本。

应该清楚的是,用于构建树的数据中的案例与 OOB 样本中的案例相同的变量可用。为了获得 OOB 样本的预测,每个样本都向下传递到当前树,并遵循树的规则,直到它到达终端节点。这会产生该特定树的 OOB 预测。

这个过程重复了很多次,每棵树都在来自训练数据的新引导样本上训练,并预测新的 OOB 样本。

随着树的数量增加,任何一个样本都将多次出现在 OOB 样本中,因此将样本在 OOB 中的 N 棵树的预测的“平均值”用作每个训练样本的 OOB 预测树 1, ..., N。通过“平均”,我们使用预测的平均值来表示连续响应,或者多数票可用于分类响应(多数票是在一组树 1, ..., N)。

例如,假设我们在 10 棵树的训练集中对 10 个样本有以下 OOB 预测

set.seed(123)
oob.p <- matrix(rpois(100, lambda = 4), ncol = 10)
colnames(oob.p) <- paste0("tree", seq_len(ncol(oob.p)))
rownames(oob.p) <- paste0("samp", seq_len(nrow(oob.p)))
oob.p[sample(length(oob.p), 50)] <- NA
oob.p

> oob.p
       tree1 tree2 tree3 tree4 tree5 tree6 tree7 tree8 tree9 tree10
samp1     NA    NA     7     8     2     1    NA     5     3      2
samp2      6    NA     5     7     3    NA    NA    NA    NA     NA
samp3      3    NA     5    NA    NA    NA     3     5    NA     NA
samp4      6    NA    10     6    NA    NA     3    NA     6     NA
samp5     NA     2    NA    NA     2    NA     6     4    NA     NA
samp6     NA     7    NA     4    NA     2     4     2    NA     NA
samp7     NA    NA    NA     5    NA    NA    NA     3     9      5
samp8      7     1     4    NA    NA     5     6    NA     7     NA
samp9      4    NA    NA     3    NA     7     6     3    NA     NA
samp10     4     8     2     2    NA    NA     4    NA    NA      4

其中NA意味着样本在该树的训练数据中(换句话说,它不在 OOB 样本中)。

每行的非值的平均值给出了整个森林NA的每个样本的 OOB 预测

> rowMeans(oob.p, na.rm = TRUE)
 samp1  samp2  samp3  samp4  samp5  samp6  samp7  samp8  samp9 samp10 
  4.00   5.25   4.00   6.20   3.50   3.80   5.50   5.00   4.60   4.00

随着每棵树被添加到森林中,我们可以计算 OOB 错误,直到包含该树。例如,以下是每个样本的累积平均值:

FUN <- function(x) {
  na <- is.na(x)
  cs <- cumsum(x[!na]) / seq_len(sum(!na))
  x[!na] <- cs
  x
}
t(apply(oob.p, 1, FUN))

> print(t(apply(oob.p, 1, FUN)), digits = 3)
       tree1 tree2 tree3 tree4 tree5 tree6 tree7 tree8 tree9 tree10
samp1     NA    NA  7.00  7.50  5.67  4.50    NA   4.6  4.33    4.0
samp2      6    NA  5.50  6.00  5.25    NA    NA    NA    NA     NA
samp3      3    NA  4.00    NA    NA    NA  3.67   4.0    NA     NA
samp4      6    NA  8.00  7.33    NA    NA  6.25    NA  6.20     NA
samp5     NA     2    NA    NA  2.00    NA  3.33   3.5    NA     NA
samp6     NA     7    NA  5.50    NA  4.33  4.25   3.8    NA     NA
samp7     NA    NA    NA  5.00    NA    NA    NA   4.0  5.67    5.5
samp8      7     4  4.00    NA    NA  4.25  4.60    NA  5.00     NA
samp9      4    NA    NA  3.50    NA  4.67  5.00   4.6    NA     NA
samp10     4     6  4.67  4.00    NA    NA  4.00    NA    NA    4.0

通过这种方式,我们可以看到预测是如何在森林中的 N 棵树上累积到给定迭代的。如果您跨行阅读,最右边的非NA值是我在上面显示的用于 OOB 预测的值。这就是可以跟踪 OOB 性能的方式 - 可以基于 N 树上累积的 OOB 预测为 OOB 样本计算 RMSEP。

请注意,显示的 R 代码并非取自 R 的randomForest包中的 randomForest 代码的内部 - 我只是敲了一些简单的代码,以便在确定每棵树的预测后您可以了解正在发生的事情。

这是因为每棵树都是从 bootstrap 样本构建的,并且在随机森林中有大量的树,因此每个训练集的观测值都在一个或多个树的 OOB 样本中,因此可以为所有的树提供 OOB 预测训练数据中的样本。

我已经忽略了一些 OOB 案例的数据缺失等问题,但这些问题也与单个回归或分类树有关。另请注意,森林中的每棵树仅使用mtry随机选择的变量。