如何从单个 TFRecords 文件向多输出 Keras 模型提供数据

数据挖掘 喀拉斯 多标签分类 多任务学习 多输出
2022-03-04 08:07:29

我知道如何使用 numpy 数组将数据提供给多输出 Keras 模型作为训练数据。但是,我将所有数据都放在一个TFRecords 文件中,该文件包含几个特征列:一个图像,用作 Keras 模型的输入,以及对应于不同分类任务的一系列输出:例如。一个输出编码图像中人的年龄,另一个输出编码性别,依此类推。

从我在示例中看到的情况来看,当模型的输出由各种头组成时,应该为模型提供多个数据源,一个用于输入,一个用于每个输出。

当数据都在一个 TFRecords 中时,有没有一种简单的方法可以做到这一点?我的意思是,不必为输入和每个输出创建单独的 TFRecords?

2个回答

在尝试了一些tf.data.map操作之后,我发现答案比预期的要容易,我只需要预处理数据并将模型的每个输出的所有标签作为字典的不同键。

首先,我从 tfrecords 文件创建一个数据集

dataset = tf.data.TFRecordDataset(tfrecords_file)

接下来,我从文件中解析数据

feature = {'image/encoded': tf.io.FixedLenFeature((), tf.string),
           'image/shape': tf.io.FixedLenFeature((3), tf.int64),
           'age': tf.io.FixedLenFeature((), tf.int64),
           'gender': tf.io.FixedLenFeature((), tf.int64),
           'ethnicity': tf.io.FixedLenFeature((), tf.int64),
 }

return tf_util.parse_pb_message(protobuff_message, feature)

dataset = dataset.map(parser).map(process_example)

在这一点上,我们有一个标准数据集,我们可以进行批处理、洗牌、扩充或任何我们想要的操作。最后,在将数据输入模型之前,我们必须对其进行转换以适应模型的要求。下面的代码显示了输入和标签预处理的示例。以前,我连接了所有标签,现在我创建了一个字典,其中模型中的输出名称作为键。

def preprocess_input_fn():
    def _preprocess_input(image,image_shape, age, gender, ethnicity):
        image = self.preprocess_image(image)
        labels = self.preprocess_labels(age, gender, ethnicity)
        return image, labels

    return _preprocess_input

def preprocess_image(image):
    image = tf.cast(image)
    image = tf.image.resize(image)
    image = (image / 127.5) - 1.0
    return image

def preprocess_labels(age,gender,ethnicity):
    gender = tf.one_hot(gender, 2)
    ethnicity = tf.one_hot(ethnicity, self.ethnic_groups)
    age = tf.one_hot(age, self.age_groups)
    return {'Gender': gender, 'Ethnicity': ethnicity, 'Age': age}

在我的模型中,Gender、Ethnicity 和 Age 是模型最后一层的名称,因此我的模型被定义为具有三个输出:

model = Model(inputs=inputs,
              outputs=[gender, ethnic_group, age_group])

现在我可以通过首先应用预处理函数来使用数据集来拟合模型:

data = dataset.map(preprocess_input_fn())

model.fit(data, epochs=...)    

考虑到您的模型将 animage作为输入并具有两个输出agegender,并且您已经使用它们生成了 TFRecord。tf.data您可以通过这种方式解码和使用您的 TFRecord :

decode_features = {
  'image'  : tf.io.FixedLenFeature([], tf.string),
  'age'    : tf.io.FixedLenFeature([1], tf.int64),
  'gender' : tf.io.FixedLenFeature([1], tf.int64),
}

def decode(serialized_example):
  features = tf.io.parse_single_example(serialized_example, features=decode_features)
  image = tf.image.decode_image(features['image_raw'], name="InputImage")
  image = tf.cast(image, tf.float32) / 128. - 1.
  labels = {}
  labels['age']    = tf.cast(features['age'], tf.int32)
  labels['gender'] = tf.cast(features['gender'], tf.int32)
  return image, labels

dataset = tf.data.TFRecordDataset('path/to/file.tfrecords')
dataset = dataset.map(decode)

model.fit(dataset, ...)
```