在尝试了一些tf.data.map操作之后,我发现答案比预期的要容易,我只需要预处理数据并将模型的每个输出的所有标签作为字典的不同键。
首先,我从 tfrecords 文件创建一个数据集
dataset = tf.data.TFRecordDataset(tfrecords_file)
接下来,我从文件中解析数据
feature = {'image/encoded': tf.io.FixedLenFeature((), tf.string),
'image/shape': tf.io.FixedLenFeature((3), tf.int64),
'age': tf.io.FixedLenFeature((), tf.int64),
'gender': tf.io.FixedLenFeature((), tf.int64),
'ethnicity': tf.io.FixedLenFeature((), tf.int64),
}
return tf_util.parse_pb_message(protobuff_message, feature)
dataset = dataset.map(parser).map(process_example)
在这一点上,我们有一个标准数据集,我们可以进行批处理、洗牌、扩充或任何我们想要的操作。最后,在将数据输入模型之前,我们必须对其进行转换以适应模型的要求。下面的代码显示了输入和标签预处理的示例。以前,我连接了所有标签,现在我创建了一个字典,其中模型中的输出名称作为键。
def preprocess_input_fn():
def _preprocess_input(image,image_shape, age, gender, ethnicity):
image = self.preprocess_image(image)
labels = self.preprocess_labels(age, gender, ethnicity)
return image, labels
return _preprocess_input
def preprocess_image(image):
image = tf.cast(image)
image = tf.image.resize(image)
image = (image / 127.5) - 1.0
return image
def preprocess_labels(age,gender,ethnicity):
gender = tf.one_hot(gender, 2)
ethnicity = tf.one_hot(ethnicity, self.ethnic_groups)
age = tf.one_hot(age, self.age_groups)
return {'Gender': gender, 'Ethnicity': ethnicity, 'Age': age}
在我的模型中,Gender、Ethnicity 和 Age 是模型最后一层的名称,因此我的模型被定义为具有三个输出:
model = Model(inputs=inputs,
outputs=[gender, ethnic_group, age_group])
现在我可以通过首先应用预处理函数来使用数据集来拟合模型:
data = dataset.map(preprocess_input_fn())
model.fit(data, epochs=...)