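I have a CNN model written in TensorFlow that classifies lung CT images. I added some TensorFlow summaries to the code so that TensorBoard can display the model's graph, scalars, histograms, and so on. In the last step, where I add the summary to the file writer with add_summary, I get an error. Here is my code: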
    import time
    from datetime import datetime

    import tensorflow as tf

    # x_img, y_label, sess, CNN_Model, learning_rate, num_epochs,
    # train_data and val_data are defined elsewhere in the script.
    def train_CNN(input):
        train_predict = CNN_Model(x_img)
        with tf.name_scope("cross_entropy"):
            cost = tf.nn.softmax_cross_entropy_with_logits_v2(logits=train_predict, labels=y_label, name='cross_entropy')
            cost = tf.reduce_mean(cost, name='reduce_mean')
            tf.summary.scalar("cost", cost)
        with tf.name_scope("optimization"):
            optimizer = tf.train.AdamOptimizer(learning_rate, name='AdamOptimizer').minimize(cost)
        with tf.name_scope("accuracy"):
            # Compare predicted and true class indices along the class axis (axis 1).
            correct_predict = tf.equal(tf.argmax(train_predict, 1), tf.argmax(y_label, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_predict, tf.float32))
            tf.summary.scalar("accuracy", accuracy)
        #tf.summary.image("input", x_img, 5)
        sess.run(tf.global_variables_initializer())
        # One timestamped log directory per training run, for TensorBoard.
        log_train_path = 'C:/temp/tensorflow_logs' + '/train_{}'.format(datetime.now().strftime("%Y-%m-%d-%H%M%S"))
        summary_writer = tf.summary.FileWriter(log_train_path)
        summary_writer.add_graph(sess.graph)
        # Merges every summary defined above (cost and accuracy).
        merged_summary = tf.summary.merge_all()
        all_time = 0
        for epoch in range(num_epochs):
            start_time = time.time()
            ep_loss = 0
            for data in train_data:
                X = data[0]
                Y = data[1]
                summary, _, c = sess.run([merged_summary, optimizer, cost], feed_dict={x_img: X, y_label: Y})
                ep_loss += c
            # Write the summary from the last training step, indexed by epoch.
            summary_writer.add_summary(summary, epoch)
            end_time = time.time()
            all_time += int(end_time - start_time)
            print('Epoch', epoch + 1, 'completed out of', num_epochs, 'loss:', ep_loss, 'time usage: ' + str(int(end_time - start_time)) + ' seconds')
            print('Accuracy of this epoch:', accuracy.eval({x_img: [i[0] for i in val_data], y_label: [i[1] for i in val_data]}))
        print('Final Accuracy:', accuracy.eval({x_img: [i[0] for i in val_data], y_label: [i[1] for i in val_data]}), 'time usage: ' + str(all_time) + ' seconds')
When I run the model it gives me an error. Can anyone tell me how to fix it?
Here is the error:
InvalidArgumentError: Expected dimension in the range [-1, 1), but got 1
[[Node: accuracy/ArgMax_1 = ArgMax[T=DT_FLOAT, Tidx=DT_INT32, output_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_y_label_0_1, accuracy/ArgMax_1/dimension)]]
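The message says ArgMax was asked to reduce along dimension 1 of a tensor whose only valid dimensions are [-1, 1), which means the value fed to `y_label` at that step had rank 1. One plausible cause, not confirmed by the post, is that each `data[1]` in `train_data` is a single one-hot label of shape `(num_classes,)` rather than a batch of shape `(batch, num_classes)`. Because `tf.summary.merge_all()` also picks up the `accuracy` scalar, running `merged_summary` inside the training loop evaluates `tf.argmax(y_label, 1)` on that per-sample label. The pure-Python sketch below mimics the rank check; `rank_of` and `check_argmax_axis` are hypothetical helpers (not TensorFlow APIs) and the labels are made-up examples.

```python
def rank_of(value):
    """Return how many nested list/tuple levels `value` has (its tensor rank)."""
    rank = 0
    while isinstance(value, (list, tuple)):
        rank += 1
        # Descend into the first element; an empty sequence ends the walk.
        value = value[0] if value else None
    return rank


def check_argmax_axis(value, axis):
    """Hypothetical mimic of ArgMax's axis check: valid axes are [-rank, rank)."""
    rank = rank_of(value)
    if not -rank <= axis < rank:
        raise ValueError(
            f"Expected dimension in the range [{-rank}, {rank}), but got {axis}"
        )
    return rank


single_label = [0.0, 1.0]        # one one-hot label: shape (2,), rank 1
batched_label = [single_label]   # same label with a batch dimension: shape (1, 2), rank 2

print(check_argmax_axis(batched_label, 1))  # axis 1 is valid for rank 2 -> prints 2

try:
    check_argmax_axis(single_label, 1)
except ValueError as err:
    print(err)  # Expected dimension in the range [-1, 1), but got 1
```

If that is the cause, giving each fed value a leading batch dimension (for example `feed_dict={x_img: [X], y_label: [Y]}`, or stacking several samples into one batch) would likely give `y_label` the rank-2 shape that `tf.argmax(y_label, 1)` expects.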