用于复杂数据的 TensorFlow 数据集批处理

数据挖掘 张量流
2022-01-20 23:55:58

我尝试按照此链接中的示例进行操作:

https://www.tensorflow.org/programmers_guide/datasets

但我完全不知道如何运行会话。我理解第一个参数是要运行的操作,而 feed_dict 是占位符(我的理解是训练或测试数据集的批次),

所以,这是我的代码:

batch_size = 100
handle_mix = tf.placeholder(tf.float64, shape=[])
handle_src0 = tf.placeholder(tf.float64, shape=[])
handle_src1 = tf.placeholder(tf.float64, shape=[])
handle_src2 = tf.placeholder(tf.float64, shape=[])
handle_src3 = tf.placeholder(tf.float64, shape=[])

我从 mp4 音轨和词干创建数据集,读取混合和源幅度,并将它们填充以适合批处理

dataset = tf.data.Dataset.from_tensor_slices(
    {"x_mixed":padded_lbl, "y_src0": padded_src[0], "y_src1":   padded_src[1],"y_src2": padded_src[1], "y_src3": padded_src[1]})

dataset = dataset.shuffle(1000).repeat().batch(batch_size)

iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)

从我应该做的例子中:

next_element = iterator.get_next()

training_init_op = iterator.make_initializer(dataset)
for _ in range(20):
    # Initialize an iterator over the training dataset.
    sess.run(training_init_op)
    for _ in range(100):
        sess.run(next_element)

但是,我有一个损失、摘要和优化器操作,需要将数据作为批处理提供,下面是另一个示例:

l, _, summary = sess.run([loss_fn, optimizer, summary_op], feed_dict={handle_mix: batch_mix, handle_src0: batch_src0,
                             handle_src1: batch_src1, handle_src2: batch_src2, handle_src3: batch_src3})

所以我想像:

batch_mix, batch_src0, batch_src1, batch_src2, batch_src3 = data.train.next_batch(batch_size)

或者可能是单独运行以先获取批次,然后按上述方式运行优化,例如:

batch_mix, batch_src0, batch_src1, batch_src2, batch_src3 = sess.run(next_element)
l, _, summary = sess.run([loss_fn, optimizer, summary_op], feed_dict={handle_mix: batch_mix, handle_src0: batch_src0, handle_src1: batch_src1, handle_src2: batch_src2, handle_src3: batch_src3})

最后一次尝试返回了在 tf.data.Dataset.from_tensor_slices 中创建的批次的字符串名称(“x_mixed”、“y_src0”、...等),并且未能在会话中转换为 tf.float64 占位符。

你能告诉我如何创建这个数据集吗,首先张量切片的结构可能有错误,然后如何对它们进行批处理,

非常感谢你,

马纳尔

1个回答

根据我从您的代码中可以理解的内容,您似乎需要使用可初始化的迭代器。原因如下:

  • 您正在从占位符创建数据集。

这是我的解决方案:

batch_size = 100
handle_mix = tf.placeholder(tf.float64, shape=[])
handle_src0 = tf.placeholder(tf.float64, shape=[])
handle_src1 = tf.placeholder(tf.float64, shape=[])
handle_src2 = tf.placeholder(tf.float64, shape=[])
handle_src3 = tf.placeholder(tf.float64, shape=[])

将一个元组传递 .from_tensor_slicestf.data.Dataset类的方法

dataset = tf.data.Dataset.from_tensor_slices((handle_mix, handle_src0,
                                               handle_src1, handle_src2,
                                               handle_src3))

dataset = dataset.shuffle(1000).repeat().batch(batch_size)

iter = dataset.make_initializable_iterator()

# unpack five values since dataset was created from five placeholders    
a, b, c, d, e = iter.get_next()

创建一个会话对象并初始化迭代器,确保将“您的数据”传递给占位符。“你的数据”可以是 numpy 数组

sess = tf.Session()
sess.run(iter.initializer, feed_dict={handle_mix:'your data',
         handle_src0:'your data',handle_src1:'your data',
         handle_src2:'your data',handle_src3:'your data'})

运行以下命令以打印迭代器的内容

print "%r, %r, %r, %r, %r" % sess.run([a,b,c,d,e])

TensorFlow 官方文档中提供了有关如何使用数据集 API 的相关详细信息或说明。