一个简单的逻辑回归模型如何在 MNIST 上实现 92% 的分类准确率?

机器算法验证 物流 图像处理
2022-01-22 00:25:18

尽管 MNIST 数据集中的所有图像都是居中的,具有相似的比例,并且没有旋转,但它们具有显着的笔迹变化,这让我感到困惑,线性模型如何实现如此高的分类精度。

据我所知,考虑到显着的手写变化,数字在 784 维空间中应该是线性不可分的,即应该有一点复杂(虽然不是很复杂)的非线性边界来分隔不同的数字,类似于被广泛引用的XOR任何线性分类器都不能区分正类和负类的例子。多类逻辑回归如何在完全线性特征(没有多项式特征)的情况下产生如此高的准确度,这让我感到莫名其妙。

例如,给定图像中的任何像素,数字的不同手写变化23可以使该像素发光或不发光。因此,通过一组学习的权重,每个像素可以使一个数字看起来像2以及ASA3. 只有通过像素值的组合才能判断一个数字是否是2或一个3. 大多数数字对都是如此。那么,逻辑回归如何盲目地将其决策独立地基于所有像素值(根本不考虑任何像素间依赖关系),能够实现如此高的精度。

我知道我在某个地方错了,或者只是高估了图像的变化。但是,如果有人可以帮助我直观地了解数字如何“几乎”线性分离,那就太好了。

1个回答

tl;dr即使这是一个图像分类数据集,它仍然是一项非常简单的任务,人们可以轻松地找到从输入到预测的直接映射。


回答:

这是一个非常有趣的问题,并且由于逻辑回归的简单性,您实际上可以找到答案。

逻辑回归所做的是对每个图像接受784输入并将它们与权重相乘以生成其预测。有趣的是,由于输入和输出之间的直接映射(即没有隐藏层),每个权重的值对应于每个权重的多少784在计算每个类别的概率时会考虑输入。现在,通过获取每个类的权重并将它们重塑为28×28(即图像分辨率),我们可以知道哪些像素对每个类的计算最重要

再次注意,这些是权重

现在看一下上面的图像并关注前两位数(即零和一)。蓝色权重意味着该像素的强度对该类的贡献很大,红色值意味着它的贡献是负面的。

现在想象一下,一个人如何画一个0? 他画了一个圆形,中间是空的。这正是重量增加的原因。事实上,如果有人画了图像的中间,它就会被视为负数为零。因此,要识别零,您不需要一些复杂的过滤器和高级功能。您可以只看绘制的像素位置并据此进行判断。

同样的事情1. 它总是在图像中间有一条垂直的直线。其他的都是负数。

其余的数字稍微复杂一些,但你可以想象一下2, 这3, 这78. 其余的数字有点困难,这实际上限制了逻辑回归达到 90 年代的高位。

通过这一点,您可以看到逻辑回归很有可能获得很多正确的图像,这就是它得分如此之高的原因。


重现上图的代码有点过时了,但在这里:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)