我阅读了一些关于神经网络反向传播的教程,并决定从头开始实现一个。在过去的几天里,我试图在我的代码中找到这个单一的错误,但没有成功。
我遵循本教程希望能够实现正弦函数逼近器。这是一个简单的网络:1 个输入神经元、10 个隐藏神经元和 1 个输出神经元。第二层的激活函数是 sigmoid。完全相同的模型很容易在 Tensorflow 中工作。
def sigmoid(x):
return 1 / (1 + np.math.e ** -x)
def sigmoid_deriv(x):
return sigmoid(x) * (1 - sigmoid(x))
x_data = np.random.rand(500) * 15.0
y_data = [sin(x) for x in x_data]
ETA = .01
layer1 = 0
layer1_weights = np.random.rand(10) * 2. - 1.
layer2 = np.zeros(10)
layer2_weights = np.random.rand(10) * 2. - 1.
layer3 = 0
for loop_iter in range(500000):
# data init
index = np.random.randint(0, 500)
x = x_data[index]
y = y_data[index]
# forward propagation
# layer 1
layer1 = x
# layer 2
layer2 = layer1_weights * layer1
# layer 3
layer3 = sum(sigmoid(layer2) * layer2_weights)
# error
error = .5 * (layer3 - y) ** 2 # L2 loss
# backpropagation
# error_wrt_layer3 * layer3_wrt_weights_layer2
error_wrt_layer2_weights = (y - layer3) * sigmoid(layer2)
# error_wrt_layer3 * layer3_wrt_out_layer2 * out_layer2_wrt_in_layer2 * in_layer2_wrt_weights_layer1
error_wrt_layer1_weights = (y - layer3) * layer2_weights * sigmoid_deriv(sigmoid(layer2)) * layer1
# update the weights
layer2_weights -= ETA * error_wrt_layer2_weights
layer1_weights -= ETA * error_wrt_layer1_weights
if loop_iter % 10000 == 0:
print(error)
出乎意料的行为只是网络没有收敛。请查看我的 error_wrt_... 衍生品。问题应该在那里。
这是它完美运行的 TensorFlow 代码:
x_data = np.array(np.random.rand(500)).reshape(500, 1)
y_data = np.array([sin(x) for x in x_data]).reshape(500, 1)
x = tf.placeholder(tf.float32, shape=[None, 1])
y_true = tf.placeholder(tf.float32, shape=[None, 1])
W = tf.Variable(tf.random_uniform([1, 10], -1.0, 1.0))
hidden1 = tf.nn.sigmoid(tf.matmul(x, W))
W_hidden = tf.Variable(tf.random_uniform([10, 1], -1.0, 1.0))
output = tf.matmul(hidden1, W_hidden)
loss = tf.square(output - y_true) / 2.
optimizer = tf.train.GradientDescentOptimizer(.01)
train = optimizer.minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(500000):
rand_index = np.random.randint(0, 500)
_, error = sess.run([train, loss], feed_dict={x: [x_data[rand_index]],
y_true: [y_data[rand_index]]})
if i % 10000 == 0:
print(error)
sess.close()