I'm using the following code to import a bunch of .png images and decode them with TensorFlow:
from __future__ import absolute_import, division, print_function
import tensorflow as tf
import numpy as np
import os
tf.enable_eager_execution()
NUM_TRAINING_SAMPLES = 333
NUM_CLASSES = 3
BATCH_SIZE = 5
NUM_EPOCHS = 6
INPUT_SIZE = (256, 256, 3)
# replace=False so that no file is sampled twice
random_indices = np.random.choice(range(13000), NUM_TRAINING_SAMPLES, replace=False)
directory = "/home/local/CYCLOMEDIA001/ebos/Downloads/SYNTHIA_RAND_CVPR16"
directory_images = "/home/Downloads/SYNTHIA_RAND_CVPR16/RGB"
directory_labels = "/home/Downloads/SYNTHIA_RAND_CVPR16/GT"
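# os.listdir returns entries in arbitrary order, so sort both listings
# to keep the i-th image and the i-th label aligned
train_images = np.array(sorted(os.listdir(directory_images)))
train_labels = np.array(sorted(os.listdir(directory_labels)))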
train_images = train_images[random_indices]
train_labels = train_labels[random_indices]
train_images = [tf.read_file(os.path.join(directory_images, img)) for img in train_images]
train_labels = [tf.read_file(os.path.join(directory_labels, img)) for img in train_labels]
train_images = [tf.io.decode_image(img, channels=3) for img in train_images]
train_labels = [tf.io.decode_image(img, channels=3) for img in train_labels]
train_images = tf.image.resize_images(train_images, INPUT_SIZE[:2])
train_labels = tf.image.resize_images(train_labels, INPUT_SIZE[:2])
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_dataset = train_dataset.batch(3)
print(train_dataset.output_types)
This returns:
(tf.float32, tf.float32)
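However, according to the documentation, tf.io.decode_image should return a tensor of type uint8 or uint16. Why, and where, does the conversion to float32 happen?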
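I checked every intermediate step with print statements, but that didn't tell me much, because most of the intermediate results are just objects of class "tensorflow.python.framework.ops.EagerTensor".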
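One way to locate the conversion is to print the .dtype attribute of each intermediate tensor rather than the tensor object itself. The sketch below is a minimal illustration under stated assumptions: it targets TensorFlow 2.x (eager execution on by default, tf.image.resize in place of tf.image.resize_images, Dataset.element_spec in place of output_types), and it builds a synthetic PNG in memory instead of reading the SYNTHIA files. According to the TensorFlow documentation, bilinear resizing (the default method) returns float32, while nearest-neighbour resizing keeps the input dtype.

```python
import numpy as np
import tensorflow as tf  # assumption: TensorFlow 2.x

# Synthetic stand-in for one .png file (not from the original post):
# a random 64x48 RGB uint8 image encoded to PNG bytes in memory.
rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(64, 48, 3), dtype=np.uint8)
png_bytes = tf.io.encode_png(pixels)

# Step 1: decoding an 8-bit PNG yields a uint8 tensor.
decoded = tf.io.decode_image(png_bytes, channels=3, expand_animations=False)
print("decoded:", decoded.dtype)

# Step 2: bilinear resizing (the default method) yields float32.
resized = tf.image.resize(decoded, (256, 256))
print("bilinear resize:", resized.dtype)

# Nearest-neighbour resizing keeps the input dtype.
resized_nn = tf.image.resize(decoded, (256, 256), method="nearest")
print("nearest resize:", resized_nn.dtype)

# To get uint8 back after bilinear resizing, round and saturate-cast.
# The float values are still on the 0..255 scale, so a plain cast is used
# rather than tf.image.convert_image_dtype (which expects floats in [0, 1]).
resized_u8 = tf.saturate_cast(tf.round(resized), tf.uint8)
print("cast back:", resized_u8.dtype)

# A dataset built from these tensors simply inherits their dtype.
dataset = tf.data.Dataset.from_tensor_slices(tf.stack([resized, resized]))
print(dataset.element_spec)
```

For label images such as the GT files, resizing with method="nearest" may be preferable anyway, since it avoids interpolating between label values.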