使用 TensorFlow-Slim 从现有检查点微调模型

数据挖掘 深度学习 张量流 成立之初
2022-02-14 20:13:13

我正在尝试使用TensorFlow-Slim使用新的图像数据集重新训练预训练模型的最后一层

假设我想在鲜花数据集上微调 inception-v3。Inception_v3 在具有 1000 个类别标签的 ImageNet 上进行训练,但鲜花数据集只有 5 个类别。由于数据集非常小,我们将只训练新层。

官方 tf github 页面上的示例显示了如何执行此操作:

$ DATASET_DIR=/tmp/flowers
$ TRAIN_DIR=/tmp/flowers-models/inception_v3
$ CHECKPOINT_PATH=/tmp/my_checkpoints/inception_v3.ckpt
$ python train_image_classifier.py \
    --train_dir=${TRAIN_DIR} \
    --dataset_dir=${DATASET_DIR} \
    --dataset_name=flowers \
    --dataset_split_name=train \
    --model_name=inception_v3 \
    --checkpoint_path=${CHECKPOINT_PATH} \
    --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits/Logits \
    --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits/Logits

我无法完全理解上述代码中的所有参数:

train_dir= ?

dataset_dir= 新数据集目录位置

dataset_name= 数据集的名称(但为什么

dataset_split_name= ?

model_name=我们要训练的模型的名称

checkpoint_path= 模型检查点的路径

checkpoint_exclude_scopes= ?

trainable_scopes= ?


如果我对任何参数有误,请帮我弄清楚这些参数的含义并纠正我?

注意:我知道我们可以使用tensorflow 官方网站上提到的方法重新训练 inception_v3,但我想对 tensorflow-slim 做同样的事情。

2个回答

您可以从源代码中获取您需要的所有信息。

  • train_dir:写入检查点和事件日志的目录。(默认:/tmp/tfmodel/)

  • dataset_dir:存储数据集文件的目录。

  • dataset_name:要加载的数据集的名称。(默认:imagenet)

  • dataset_split_name:训练/测试拆分的名称。(默认:火车)

  • model_name:要训练的架构的名称。(默认:inception_v3)

  • checkpoint_path:从其微调的检查点的路径。

  • checkpoint_exclude_scopes:从检查点恢复时要排除的变量范围的逗号分隔列表。

  • trainable_scopes:以逗号分隔的范围列表,用于过滤要训练的变量集。默认情况下, None 将训练所有变量。

除了@Icyblade 通过引用源代码提出的建议之外,我还想添加一些内容。

  • dataset_name是一个选项,告诉 slim 如何读取文件夹中的 TFRecords dataset_dir
    默认值imagenet意味着 TFRecords 必须具有以下格式:train-00146-of-00168.tfrecordvalidation-00003-of-00019.tfrecord.
    该值flowers意味着 slim 期望 TFRecords 具有形式flowers_train_00146-of-00168.tfrecordflowers_validation_00003-of-00019.tfrecord这种格式化 tfrecords 的特殊方式在源代码中定义。
    不同的样本数据集有不同的格式化规则,例如在这里你可以看到 cifar10
  • 此外,dataset_split_name告诉转换器如何命名 TFRecords 文件。对于flowers 数据集,它们是trainand validation,对于 cifar10 它们是trainand test,等等。
  • 关于checkpoint_exclude_scopestrainable_scopes,长话短说,这些命令告诉训练算法从最后一个可用检查点中删除网络的最后两层(通过checkpoint_path选项传递或从 中指定的文件夹中读取train_dir),并使其重新训练CNN 对它们进行处理,但使用您在 TFRecords 文件中提供的图像。因此,使用这些选项意味着对另一组图像执行 CNN 的重新训练——或者,正如苗条的开发人员在 README 中所说,从现有检查点(评论中的链接)进行微调。

为了给你一个关于如何构建你自己的 TFRecord 转换器/创建器的提示,我建议它基于 download_and_convert_flowers.py 脚本,因为输入数据集是常规的,图像分为类别,每个类别都是一个文件夹。
此外,我会通过添加读取输入 png 图像的选项来扩展它(鲜花图片只有 jpg,因此脚本不需要这种转换,但你未来的数据集可能)。首先将其逐个分解并删除与鲜花数据集相关的硬编码信息,以使其不那么具体。