为什么要下采样?

机器算法验证 机器学习 分类
2022-01-25 07:40:46

假设我想学习一个预测电子邮件是否为垃圾邮件的分类器。假设只有 1% 的电子邮件是垃圾邮件。

最简单的方法是学习一个简单的分类器,该分类器说所有电子邮件都不是垃圾邮件。这个分类器会给我们 99% 的准确率,但它不会学到任何有趣的东西,并且会有 100% 的假阴性率。

为了解决这个问题,人们告诉我要“下采样”,或者在数据子集上学习,其中 50% 的示例是垃圾邮件,50% 不是垃圾邮件。

但我担心这种方法,因为一旦我们构建了这个分类器并开始在真实的电子邮件语料库上使用它(而不是 50/50 测试集),它可能会预测很多电子邮件是垃圾邮件真的不是。只是因为它习惯于看到比数据集中实际更多的垃圾邮件。

那么我们如何解决这个问题呢?

(“上采样”或多次重复正训练示例,因此 50% 的数据是正训练示例,似乎遇到了类似的问题。)

4个回答

大多数分类模型实际上不会产生二元决策,而是产生连续决策值(例如,逻辑回归模型输出概率,SVM 输出到超平面的有符号距离,......)。使用决策值,我们可以对测试样本进行排序,从“几乎肯定是积极的”到“几乎肯定是消极的”。

根据决策值,您始终可以分配一些截止值,以这样一种方式配置分类器,即一定比例的数据被标记为正值。可以通过模型的ROC或 PR 曲线确定适当的阈值。无论训练集中使用的余额如何,您都可以使用决策阈值。换句话说,像上采样或下采样这样的技术与此正交。

假设模型优于随机模型,您可以直观地看到增加正分类的阈值(这会导致较少的正预测)以降低召回率为代价提高模型的精度,反之亦然。

将 SVM 视为一个直观的示例:主要挑战是学习分离超平面的方向上采样或下采样可以帮助解决这个问题(我建议更喜欢上采样而不是下采样)。当超平面的方向很好时,我们可以使用决策阈值(例如到超平面的有符号距离)来获得所需的正预测分数。

这里真正的问题是您对指标的选择:准确率百分比是衡量模型在不平衡数据集上成功的一个糟糕指标(正是因为您提到的原因:在这种情况下实现 99% 的准确率是微不足道的)。

在拟合模型之前平衡您的数据集是一个糟糕的解决方案,因为它会使您的模型产生偏差并且(更糟糕的是)会丢弃可能有用的数据。

你最好平衡你的准确性指标,而不是平衡你的数据。例如,您可以在评估模型时使用平衡的准确性(error for the positive class + error for the negative class)/2如果您预测全部为正面或全部为负面,则该指标将50%是一个很好的属性。

在我看来,下采样的唯一原因是当您有太多数据并且无法适合您的模型时。许多分类器(例如逻辑回归)在不平衡的数据上效果很好。

一如既往@Marc Claesen的好答案。

我只想补充一点,似乎缺少的关键概念是成本函数的概念。在任何模型中,您都有隐含或显式的误报成本(FN/FP)。对于所描述的不平衡数据,通常愿意采用 5:1 或 10:1 的比率。有许多方法可以将成本函数引入模型。一种传统的方法是对模型产生的概率施加一个概率截止值——这对于逻辑回归非常有效。

用于不自然地输出概率估计的严格分类器的一种方法是以会产生您感兴趣的成本函数的比率对多数类进行欠采样。请注意,如果您以 50/50 进行采样,您将引入任意成本函数。成本函数是不同的,但就像您以流行率采样一样随意。您通常可以预测与您的成本函数相对应的适当采样率(通常不是 50/50),但与我交谈过的大多数从业者只是尝试几个采样率并选择最接近其成本函数的一个。

直接回答 Jessica 的问题 - 下采样的一个原因是当您处理大型数据集并面临计算机内存限制或只是想减少处理时间时。从负例中进行下采样(即,在没有替换的情况下随机抽取样本)将数据集减小到更易于管理的大小。

您在问题中提到使用“分类器”,但没有具体说明是哪一个。您可能想要避免的一种分类器是决策树。在对罕见事件数据运行一个简单的决策树时,我经常发现该树只构建一个根,因为它很难将这么少的积极案例分成几类。可能有更复杂的方法可以提高树在稀有事件中的性能——我不知道有什么能在我脑海中浮现。

因此,正如 Marc Claesen 所建议的,使用返回连续预测概率值的逻辑回归是一种更好的方法。如果您对数据执行逻辑回归,则尽管记录较少,但系数仍保持无偏。你将不得不调整拦截,β0,根据 Hosmer 和 Lemeshow,2000 年的公式,从您的下采样回归中得出:

βc=β0log(p+1p+)

在哪里p+是您的预下采样人群中阳性病例的比例。

使用 ROC 查找您首选的垃圾邮件 ID 阈值可以通过首先使用在下采样数据集上转换的模型系数对完整数据集进行评分,然后从最高到最低的垃圾邮件预测概率对记录进行排序。接下来,顶一下n得分记录,其中n是您要设置的任何阈值(100、500、1000 等),然后计算顶部误报案例的百分比n病例和其余较低层中假阴性病例的百分比N-n案例,以便找到满足您需求的敏感性/特异性的适当平衡。