如何训练机器,使其可以将“超出范围/类别”作为神经网络的输出

数据挖掘 神经网络 美国有线电视新闻网 训练 自动编码器
2022-02-14 07:08:01

我知道我无法正确说出问题的标题。所以我想在这里解释这个问题:假设,我构建并训练了一个 CNN 来识别从 0 到 9 的数字。但是,当我部署 CNN 时,有人给出了“#”作为输入(除了 0-9 之外的任何值)。在训练期间我可以对我的神经网络做些什么,以便它的输出可以说这不是训练字符?

我想提供第二个示例:假设我们想使用自动编码器进行降噪,而不是分类。同样,CNN 自动编码器被训练为 0 到 9。现在,我们如何准备它,以便如果有人给出“$”符号(它可以是不是 0-9 的任何东西)作为输入,它将能够确定这个符号不是它所训练的?CNN自动编码器将能够根据它给出输出吗?

2个回答

我认为最好的方法是增加一些数据并有一个额外的输出类“未知”。但是,如果这不可能或无法重新训练神经网络,我会比较隐藏层输出的分布。

对于下面的 CNN 架构,计算训练数据的 flatten 层之后的隐藏层输出的经验分布(例如n1)。保存此分布并在部署期间计算输出n1对于测试实例。如果输出n1对应于训练示例的经验分布的极不可能的值,返回“未知”。否则,预测此实例的数字。

在此处输入图像描述

图片来源

我想到了三个想法(从简单到复杂)

  1. 为任何不是数字的内容添加一个附加类别,并在这些类别上训练您的网络k+1类别。

  2. 首先应用另一个预测器,该预测器已经过训练以区分“数字”和“无数字”。如果输入被分类为数字,那么您就运行您的数字识别网络。(这种方法可能会使迁移学习更容易,即第一步应用现有模型)

  3. 将两个任务合并到一个网络中,使其成为一个多任务分类,即您的网络不仅包括识别数字的层,还包括“数字与无数字”的二进制分类。由于这些任务密切相关,因此这两个任务可能会受益于共享参数(即使用相同的特征)。论文An Overview of Multi-Task Learning in Deep Neural Networks更详细地描述了这种方法。(请注意,此方法与此列表中的第一个想法不同,因为该方法应用了两个单独的分类,而第一个仅进行了一个分类)

但是,作为免责声明:我自己没有尝试过第三个想法,而是将第二个想法用于类似的问题。