假设我已经在预训练网络上完成了迁移学习以识别 10 个对象。我怎样才能添加一个网络可以在不丢失我已经训练的所有 10 个类别以及来自原始预训练模型的信息的情况下分类的项目?一位朋友告诉我,该领域正在进行积极的研究,但我找不到任何相关的论文或名称可搜索?
如何为深度学习模型添加新类别?
如果这只是一次性案例,您可以简单地重新训练神经网络。如果您经常需要添加新类,那么这是一个坏主意。在这种情况下,您想要做的事情称为基于内容的图像检索 (CBIR),或者简称为图像检索或视觉搜索。我将在下面的回答中解释这两种情况。
一次性案例
如果这种情况只发生一次——你忘记了第 11 课,或者你的客户改变了主意——但它不会再次发生,那么你可以简单地将第 11 个输出节点添加到最后一层。随机初始化此节点的权重,但将您已有的权重用于其他输出。然后,像往常一样训练它。固定一些权重可能会有所帮助,即不要训练这些权重。
一个极端的情况是只训练新的权重,而其他所有的权重都保持不变。但我不确定这是否会运作良好 - 可能值得一试。
基于内容的图像检索
考虑以下示例:您正在为一家 CD 商店工作,该商店希望他们的客户能够为专辑封面拍照,并且该应用程序向他们展示了他们在其在线商店中扫描的 CD。在这种情况下,您必须针对商店中的每张新 CD 重新训练网络。这可能是每天 5 张新 CD,因此以这种方式重新训练网络是不合适的。
解决方案是训练一个网络,将图像映射到一个特征空间。每个图像将由一个描述符表示,例如一个 256 维向量。您可以通过计算此描述符来“分类”图像,并将其与您的描述符数据库(即您商店中所有 CD 的描述符)进行比较。数据库中最接近的描述符获胜。
你如何训练神经网络来学习这样的描述向量?这是一个活跃的研究领域。您可以通过搜索“图像检索”或“度量学习”等关键字找到最近的工作。
现在,人们通常采用预训练的网络,例如 VGG-16,切断 FC 层,并使用最终的卷积作为描述向量。您可以进一步训练该网络,例如使用具有三元组损失的连体网络。
这很容易做到。
首先用这 10 个类构建一个模型并将模型保存为 base_model。
加载 base_model 并定义一个名为 new_model 的新模型 -
new_model = Sequential()
然后将 base_model 的层添加到 new_model -
# getting all the layers except the last two layers
for layer in base_model.layers[:-2]: #just exclude the last two layers from base_model
new_model.add(layer)
现在使新模型的层不可训练,因为您不希望再次训练您的模型。
# prevent the already trained layers from being trained again
for layer in new_model.layers:
layer.trainable = False
现在,当您迁移学习时,当您删除最后一层时,模型会忘记 10 个类,因此我们必须将 base_model 的权重保留到 new_model -
weights_training = base_model.layers[-2].get_weights()
new_model.layers[-2].set_weights(weights_training)
现在在最后添加一个密集层,在这个例子中我们将只训练这个密集层。
new_model.add(Dense(CLASSES, name = 'new_Dense', activation = 'softmax'))
现在训练模型,我希望它为所有 11 个类提供正确的输出。
快乐学习。
您的网络拓扑可能看起来不同,但最终,您的预训练网络有一个层,用于处理 10 个原始类的识别。引入第 11、12.. n 类的最简单(且有效)的技巧是使用最后一层之前的所有层,并添加一个额外的层(在新模型中,或作为并行层),该层也将位于除了最后一层之外,除了最后一层之外,它看起来与 10class 层相似(这很可能是密集层的 matmul 和[len(dense layer), 10]
带有可选偏差的形状矩阵)。
您的新层将是一个 matmul 层 shape [len(dense layer), len(new classes)]
。
如果无法访问原始训练数据,您将有两种选择:
- 通过允许“新”模型仅优化新权重来冻结原始层中的所有权重。这将为您提供与原始 10 个类完全相同的预测能力,并可能为新类提供良好的性能。
- 一次训练整个网络(通过传播新类的错误),这可能适用于新类,但你最终会得到 10 个类的无效原始解决方案(因为较低类和最后一层的权重会改变不会更新以匹配这些更改)。
虽然,如果您可以访问原始训练数据,您可以轻松地将新类添加到原始网络并重新训练它以支持 11 个开箱即用的类。