I have the following code, where x is the input data, w is the weight matrix, and by and bh are the biases of the visible and hidden units.
import numpy
import theano
import theano.tensor as T

# Construct Theano expression graph
# (x, w, bh, by, tr_rate, training_examples, feats and nhidden are defined elsewhere)
x_states = x > numpy.random.rand(training_examples, feats)
hid = T.nnet.sigmoid(T.dot(x_states, w) + bh)  # Activation function of hidden layer
hid_states = hid > numpy.random.rand(training_examples, nhidden)
vis = T.nnet.sigmoid(T.dot(hid_states, w.T) + by)
vis_states = vis > numpy.random.rand(training_examples, feats)
hid2 = T.nnet.sigmoid(T.dot(vis_states, w) + bh)
hid2_states = hid2 > numpy.random.rand(training_examples, nhidden)
xent = T.sum((x - vis)**2)  # sum of squared reconstruction errors
parameters = [w, bh, by]  # this line defines all parameters for each layer
cost = xent.mean()
pos_associations = T.dot(x.T, hid)
neg_associations = T.dot(vis.T, hid2)
wparameters = [w]
byparameters = [by]
bhparameters = [bh]
update = []
for wparam, byparam, bhparam in zip(wparameters, byparameters, bhparameters):
    update.append((wparam, wparam + tr_rate * (pos_associations - neg_associations)))
    update.append((byparam, byparam + tr_rate * (T.sum(x.T[:]) - T.sum(vis.T[:]))))
    update.append((bhparam, bhparam + tr_rate * (T.sum(hid.T[:]) - T.sum(hid2.T[:]))))
train = theano.function(
    inputs=[x],
    outputs=[cost],
    updates=update)
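What is the correct implementation of the updates? In my example this clearly does not work. I tested it on the classic MNIST dataset.
Here are the cost values I get. As you can see, the cost does not decrease. What is wrong?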
Epoch: 0
cost = [ 969572.73014003]
Epoch: 1
cost = [ 258872.77507019]
Epoch: 2
cost = [ 258872.77507019]
Epoch: 3
cost = [ 258872.77507019]
Epoch: 4
cost = [ 258872.77507019]
Epoch: 5
cost = [ 2003326.79850769]
Epoch: 6
cost = [ 258872.77507019]
Epoch: 7
cost = [ 258872.77507019]
Epoch: 8
cost = [ 258872.77507019]
Epoch: 9
cost = [ 258872.77507019]
Epoch: 10
cost = [ 258872.77507019]
Epoch: 11
cost = [ 258872.77507019]
Epoch: 12
cost = [ 2003326.79850769]
Epoch: 13
cost = [ 258872.77507019]
Epoch: 14
cost = [ 258872.77507019]
Epoch: 15
cost = [ 258872.77507019]
Epoch: 16
cost = [ 258872.77507019]
Epoch: 17
cost = [ 258872.77507019]
Epoch: 18
cost = [ 258872.77507019]
Epoch: 19
cost = [ 258872.77507019]
Epoch: 20
cost = [ 2003326.79850769]
Epoch: 21
cost = [ 258872.77507019]
Epoch: 22
cost = [ 258872.77507019]
Epoch: 23
cost = [ 258872.77507019]
Epoch: 24
cost = [ 258872.77507019]
Epoch: 25
cost = [ 258872.77507019]
Epoch: 26
cost = [ 258872.77507019]
Epoch: 27
cost = [ 258872.77507019]
Epoch: 28
cost = [ 2003326.79850769]
Epoch: 29
cost = [ 258872.77507019]
Epoch: 30
cost = [ 258872.77507019]
Epoch: 31
cost = [ 258872.77507019]
Epoch: 32
cost = [ 258872.77507019]
Epoch: 33
cost = [ 258872.77507019]
Epoch: 34
cost = [ 258872.77507019]
Epoch: 35
cost = [ 258872.77507019]
Epoch: 36
cost = [ 2003326.79850769]
Epoch: 37
cost = [ 258872.77507019]
Epoch: 38
cost = [ 258872.77507019]
Epoch: 39
cost = [ 258872.77507019]
Epoch: 40
cost = [ 258872.77507019]
Epoch: 41
cost = [ 258872.77507019]
Epoch: 42
cost = [ 258872.77507019]
Epoch: 43
cost = [ 2003326.79850769]
Epoch: 44
cost = [ 258872.77507019]
Epoch: 45
cost = [ 258872.77507019]
Epoch: 46
cost = [ 258872.77507019]
Epoch: 47
cost = [ 258872.77507019]
Epoch: 48
cost = [ 258872.77507019]
Epoch: 49
cost = [ 258872.77507019]
Epoch: 50
cost = [ 258872.77507019]
Epoch: 51
cost = [ 2003326.79850769]
Epoch: 52
cost = [ 258872.77507019]
Epoch: 53
cost = [ 258872.77507019]
Epoch: 54
cost = [ 258872.77507019]
Epoch: 55
cost = [ 258872.77507019]
Epoch: 56
cost = [ 258872.77507019]
Epoch: 57
cost = [ 258872.77507019]
Epoch: 58
cost = [ 258872.77507019]
Epoch: 59
cost = [ 2003326.79850769]
Epoch: 60
cost = [ 258872.77507019]
Epoch: 61
cost = [ 258872.77507019]
Epoch: 62
cost = [ 258872.77507019]
Epoch: 63
cost = [ 258872.77507019]
Epoch: 64
cost = [ 258872.77507019]
Epoch: 65
cost = [ 258872.77507019]
Epoch: 66
cost = [ 258872.77507019]
Epoch: 67
cost = [ 2003326.79850769]
Epoch: 68
cost = [ 258872.77507019]
Epoch: 69
cost = [ 258872.77507019]
Epoch: 70
cost = [ 258872.77507019]
Epoch: 71
cost = [ 258872.77507019]
Epoch: 72
cost = [ 258872.77507019]
Epoch: 73
cost = [ 258872.77507019]
Epoch: 74
cost = [ 2003326.79850769]
Epoch: 75
cost = [ 258872.77507019]
Epoch: 76
cost = [ 258872.77507019]
Epoch: 77
cost = [ 258872.77507019]
Epoch: 78
cost = [ 258872.77507019]
Epoch: 79
cost = [ 258872.77507019]
Epoch: 80
cost = [ 258872.77507019]
Epoch: 81
cost = [ 258872.77507019]
Epoch: 82
cost = [ 2003326.79850769]
Epoch: 83
cost = [ 258872.77507019]
Epoch: 84
cost = [ 258872.77507019]
Epoch: 85
cost = [ 258872.77507019]
Epoch: 86
cost = [ 258872.77507019]
Epoch: 87
cost = [ 258872.77507019]
Epoch: 88
cost = [ 258872.77507019]
Epoch: 89
cost = [ 258872.77507019]
Epoch: 90
cost = [ 2003326.79850769]
Epoch: 91
cost = [ 258872.77507019]
Epoch: 92
cost = [ 258872.77507019]
Epoch: 93
cost = [ 258872.77507019]
Epoch: 94
cost = [ 258872.77507019]
Epoch: 95
cost = [ 258872.77507019]
Epoch: 96
cost = [ 258872.77507019]
Epoch: 97
cost = [ 258872.77507019]
Epoch: 98
cost = [ 2003326.79850769]
Epoch: 99
cost = [ 258872.77507019]
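For comparison, here is a minimal NumPy sketch of one common form of the CD-1 update that the question asks about. It is not the Theano code above and may not match the intended model exactly. The function name `cd1_step`, the helper `sigmoid`, the toy data, and the choice to average the updates over the batch are all assumptions made for illustration. It differs from the code above in two ways that may be relevant: each bias update is a per-unit vector (a mean over the example axis) instead of a single scalar sum over all elements, and fresh random numbers are drawn on every call, whereas `numpy.random.rand(...)` inside a Theano graph is evaluated once when the graph is built.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def cd1_step(x, w, by, bh, lr, rng):
    """One CD-1 step on a batch x of shape (n, feats).

    Hypothetical helper: returns updated (w, by, bh) and the mean
    squared reconstruction error per example.
    """
    n = x.shape[0]

    # Positive phase: sample visible states, compute hidden probabilities.
    x_states = (x > rng.random(x.shape)).astype(float)
    hid = sigmoid(x_states @ w + bh)
    hid_states = (hid > rng.random(hid.shape)).astype(float)

    # Negative phase: reconstruct the visibles, then the hiddens again.
    vis = sigmoid(hid_states @ w.T + by)
    vis_states = (vis > rng.random(vis.shape)).astype(float)
    hid2 = sigmoid(vis_states @ w + bh)

    pos_associations = x.T @ hid      # shape (feats, nhidden)
    neg_associations = vis.T @ hid2   # shape (feats, nhidden)

    # Assumption: updates are averaged over the batch so that the step
    # size does not grow with the number of training examples.
    w = w + lr * (pos_associations - neg_associations) / n
    by = by + lr * (x - vis).mean(axis=0)     # one value per visible unit
    bh = bh + lr * (hid - hid2).mean(axis=0)  # one value per hidden unit

    cost = np.mean(np.sum((x - vis) ** 2, axis=1))
    return w, by, bh, cost


# Toy usage on random binary data (illustrative only, not MNIST).
rng = np.random.default_rng(0)
x = (rng.random((8, 6)) > 0.5).astype(float)
w0 = 0.01 * rng.standard_normal((6, 4))
w2, by2, bh2, cost = cd1_step(x, w0, np.zeros(6), np.zeros(4), lr=0.1, rng=rng)
```

In a Theano version, the analogous changes would presumably be to sum or average the bias terms over axis 0 and to draw the noise from a symbolic random stream, but that is not shown or tested here.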