Accuracy stagnates when training on notMNIST data

data-mining python logistic-regression tensorflow
2022-03-10 06:31:14

I am a beginner in machine learning. I built a logistic classifier in Python and trained it on the notMNIST dataset using TensorFlow. My code looks like this:

import math
import tensorflow as tf

# Placeholders for the flattened 28x28 input images and the one-hot labels
features = tf.placeholder(tf.float32, shape = [None, 784])
labels = tf.placeholder(tf.float32, shape = [None, 10])

weights = tf.Variable(tf.truncated_normal(shape = [784, 10]))
bias = tf.Variable(tf.zeros(shape = [10]))
logits = tf.matmul(features, weights) + bias
prediction = tf.nn.softmax(logits)
cross_entropy = -tf.reduce_sum(labels * tf.log(prediction), reduction_indices=1)
loss = tf.reduce_mean(cross_entropy)

train_feed_dict = {features: train_features, labels: train_labels}
valid_feed_dict = {features: valid_features, labels: valid_labels}
test_feed_dict = {features: test_features, labels: test_labels}

is_correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(labels, 1))
accuracy = tf.reduce_mean(tf.cast(is_correct_prediction, tf.float32))

epochs = 5
batch_size = 50
learning_rate = 0.1

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

validation_accuracy = 0.0

with tf.Session() as session:

    session.run(tf.global_variables_initializer())
    batch_count = int(math.ceil(len(train_features)/batch_size))

    for epoch_i in range(epochs):

        for i in range(batch_count):

            # Note: every step here feeds the full training set, not a mini-batch
            session.run(optimizer, feed_dict = train_feed_dict)
            print(session.run(accuracy, feed_dict = train_feed_dict))

The problem, however, is that while the training loss keeps decreasing, the accuracy first fluctuates and then eventually stagnates (at around 0.062). I can't figure out what is wrong with the code. Any help would be appreciated. Thanks.

1 Answer

Use the tf.nn.softmax_cross_entropy_with_logits function instead of writing the cross-entropy yourself, because it is more numerically stable.

Try this: loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))

Reference: https://www.tensorflow.org/api_docs/python/tf/nn/softmax_cross_entropy_with_logits
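
For illustration, here is a minimal sketch of why the fused op matters (TF1-style, to match the question's code; the tensors logits_demo and labels_demo are made-up values, not from the question). With extreme logits, tf.nn.softmax underflows to exactly 0 for very unlikely classes, so tf.log(0) yields -inf and the hand-rolled loss blows up, whereas the fused op works on the raw logits via the log-sum-exp trick and stays finite:

import tensorflow as tf

# One made-up sample whose true class has a vanishingly small softmax probability
logits_demo = tf.constant([[100.0, 0.0, -100.0]])
labels_demo = tf.constant([[0.0, 0.0, 1.0]])

# Hand-rolled version: softmax underflows to 0, log(0) -> -inf, loss -> inf
manual = -tf.reduce_sum(labels_demo * tf.log(tf.nn.softmax(logits_demo)), axis=1)

# Fused version: computed from the raw logits, stays finite (about 200 here)
fused = tf.nn.softmax_cross_entropy_with_logits(labels=labels_demo, logits=logits_demo)

with tf.Session() as session:
    print(session.run(manual))  # [inf]
    print(session.run(fused))   # [200.]

Once a single batch produces inf or NaN this way, the resulting gradient update can corrupt the weights, which is one way training can appear to run while accuracy stalls. This is why the fused op is the standard choice for classification losses in TensorFlow.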