(theanets这里的主要作者。)正如您对问题的评论所暗示的,这实际上是一个难以学习的问题!正如训练期间优化的损失值所表明的那样,网络正在尽可能地学习表示这组输入数据的最佳过滤器。
这里要考虑的重要一点是,网络中的权重被调整为代表整个输入空间,而不仅仅是一个输入。我的猜测是您希望网络学习一个高斯 blob 特征,但这不是它的工作原理。
从网络的角度来看,它被要求代表一个从这个数据池中任意采样的输入。下一个样本中的哪些像素将为零?哪些将是非零的?网络不知道,因为输入用零和非零像素平铺整个像素空间。一组均匀填充空间的数据的最佳表示是一堆或多或少均匀分布的小值,这就是您所看到的。
相比之下,尝试将输入数据限制为高斯斑点的子集。让我们将它们全部放在像素的对角条纹中,例如:
import climate
import matplotlib.pyplot as plt
import numpy as np
import skimage.filters
import theanets
climate.enable_default_logging()
def gen_inputs(x=28, sigma=2.0):
return np.array([
skimage.filters.gaussian_filter(i, sigma).astype('f')
for i in (np.eye(x*x)*2).reshape(x, x, x*x).transpose()
]).reshape(x*x, x*x)[10::27]
data = gen_inputs()
plt.imshow(data.mean(axis=0).reshape((28, 28)))
plt.show()
net = theanets.Autoencoder([784, 9, 784])
net.train(data, weight_l2=0.0001)
w = net.find('hid1', 'w').get_value().T
img = np.zeros((3 * 28, 3 * 28), float)
for r in range(3):
for c in range(3):
img[r*28:(r+1)*28, c*28:(c+1)*28] = w[r*3+c].reshape((28, 28))
plt.imshow(img)
plt.show()
这是平均数据的图(imshow代码中的第一个):

这是学习特征的图(第二个imshow):

这些特征响应整个数据集的平均值!
如果您想让网络学习更多“个人”功能,这可能会非常棘手。你可以玩的东西:
- 按照评论中的建议,增加隐藏单元的数量。
hidden_l1=0.5尝试对隐藏单元激活 ( )使用 L1 惩罚进行训练。
- 尝试强制权重本身是稀疏的 (
weight_l1=0.5)。
祝你好运!