Contrastive divergence for an RBM

data-mining neural-network deep-learning theano rbm
2022-02-23 10:55:06

I have the following code, where x is the input data, w is the weight matrix, and by and bh are the biases of the visible and hidden units.
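As a reference point for the shapes involved, here is a minimal NumPy sketch (not the asker's code) that creates the parameters with one bias entry per unit, assuming one example per row of x. `init_rbm` is a hypothetical helper, and the 0.01 weight scale and the 500 hidden units are assumptions based on common practice, not values from the question.

```python
import numpy as np


def init_rbm(n_visible, n_hidden, seed=0):
    """Return (w, bv, bh): small random weights and zero biases (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    w = 0.01 * rng.standard_normal((n_visible, n_hidden))  # shape (feats, nhidden)
    bv = np.zeros(n_visible)  # visible biases, one per visible unit
    bh = np.zeros(n_hidden)   # hidden biases, one per hidden unit
    return w, bv, bh


# 28x28 MNIST images give 784 visible units; 500 hidden units is an arbitrary choice.
w, bv, bh = init_rbm(784, 500)
```

With this layout, T.dot(x, w) has shape (examples, nhidden) and bh broadcasts across the rows.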

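import numpy
import theano
import theano.tensor as T

# Assumed placeholder values (not given in the original): MNIST images have
# 28x28 = 784 pixels; the other numbers are arbitrary choices.
training_examples, feats, nhidden = 60000, 784, 500
tr_rate = 0.1

x = T.matrix('x')  # input data, one example per row, values in [0, 1]
# Parameters are shared variables so that `updates` can modify them
w = theano.shared(0.01 * numpy.random.randn(feats, nhidden), name='w')
bh = theano.shared(numpy.zeros(nhidden), name='bh')  # hidden biases
by = theano.shared(numpy.zeros(feats), name='by')    # visible biases

# Stochastically binarize the input
x_states = x > numpy.random.rand(training_examples, feats)

hid = T.nnet.sigmoid(T.dot(x_states, w) + bh)    # Activation function of hidden layer
hid_states = hid > numpy.random.rand(training_examples, nhidden)

# Construct Theano expression graph: reconstruct the visibles, then the hiddens again
vis = T.nnet.sigmoid(T.dot(hid_states, w.T) + by)
vis_states = vis > numpy.random.rand(training_examples, feats)

hid2 = T.nnet.sigmoid(T.dot(vis_states, w) + bh)
hid2_states = hid2 > numpy.random.rand(training_examples, nhidden)

# Squared reconstruction error (despite the name, not a cross-entropy)
xent = T.sum((x - vis)**2)
parameters = [w, bh, by]  # this line defines all parameters for each layer
cost = xent.mean()  # xent is already a scalar, so mean() leaves it unchanged

# Positive (data) and negative (reconstruction) statistics
pos_associations = T.dot(x.T, hid)
neg_associations = T.dot(vis.T, hid2)

wparameters = [w]
byparameters = [by]
bhparameters = [bh]

# Updates: positive minus negative statistics, scaled by tr_rate
update = []
for wparam, byparam, bhparam in zip(wparameters, byparameters, bhparameters):
    update.append((wparam, wparam + tr_rate * (pos_associations - neg_associations)))
    update.append((byparam, byparam + tr_rate*(T.sum(x.T[:]) - T.sum(vis.T[:]))))
    update.append((bhparam, bhparam + tr_rate*(T.sum(hid.T[:]) - T.sum(hid2.T[:]))))

# Compile: each call returns [cost] and applies the updates
train = theano.function(
    inputs=[x],
    outputs=[cost],
    updates=update)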
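The bias updates in this code reduce everything to a single number: T.sum(x.T[:]) sums over both axes, so each bias vector is shifted by the same scalar (through broadcasting) instead of receiving a per-unit update. The NumPy sketch below uses random stand-in arrays (not the asker's data) to contrast the two reductions; summing over the example axis only, `axis=0`, gives one value per unit.

```python
import numpy as np

rng = np.random.default_rng(0)
hid = rng.random((4, 3))   # stand-in for hidden probabilities: 4 examples, 3 units
hid2 = rng.random((4, 3))  # stand-in for negative-phase hidden probabilities

# The reduction used in the question's bias update: one scalar for all units.
scalar_delta = np.sum(hid.T[:]) - np.sum(hid2.T[:])

# Per-unit statistic: reduce over the example axis only.
vector_delta = hid.sum(axis=0) - hid2.sum(axis=0)

print(np.shape(scalar_delta))  # ()
print(vector_delta.shape)      # (3,)
```

The scalar is simply the sum of the per-unit values, so the scalar form discards which unit each contribution came from.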

What is the correct way to implement the updates? In my example they clearly do not work. I tested this on the classic MNIST dataset.
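For reference, one common way to write the CD-1 updates averages the statistics over a batch of N examples. Here η stands for tr_rate, b_v for by, b_h for bh, and h, v, h' for hid, vis and hid2, with one example per row (row n written with subscript n). This is a standard formulation offered as a reference, not something stated in the question. The bias terms are sums over examples only, so each is a vector with one entry per unit.

```latex
\begin{aligned}
\Delta w   &= \frac{\eta}{N}\left(x^\top h - v^\top h'\right) \\
\Delta b_v &= \frac{\eta}{N}\sum_{n=1}^{N}\left(x_n - v_n\right) \\
\Delta b_h &= \frac{\eta}{N}\sum_{n=1}^{N}\left(h_n - h'_n\right)
\end{aligned}
```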

Here are the cost values I get. As you can see, the cost does not decrease. What is wrong?

Epoch: 0
cost =  [ 969572.73014003]
Epoch: 1
cost =  [ 258872.77507019]
Epoch: 2
cost =  [ 258872.77507019]
Epoch: 3
cost =  [ 258872.77507019]
Epoch: 4
cost =  [ 258872.77507019]
Epoch: 5
cost =  [ 2003326.79850769]
Epoch: 6
cost =  [ 258872.77507019]
Epoch: 7
cost =  [ 258872.77507019]
Epoch: 8
cost =  [ 258872.77507019]
Epoch: 9
cost =  [ 258872.77507019]
Epoch: 10
cost =  [ 258872.77507019]
Epoch: 11
cost =  [ 258872.77507019]
Epoch: 12
cost =  [ 2003326.79850769]
Epoch: 13
cost =  [ 258872.77507019]
Epoch: 14
cost =  [ 258872.77507019]
Epoch: 15
cost =  [ 258872.77507019]
Epoch: 16
cost =  [ 258872.77507019]
Epoch: 17
cost =  [ 258872.77507019]
Epoch: 18
cost =  [ 258872.77507019]
Epoch: 19
cost =  [ 258872.77507019]
Epoch: 20
cost =  [ 2003326.79850769]
Epoch: 21
cost =  [ 258872.77507019]
Epoch: 22
cost =  [ 258872.77507019]
Epoch: 23
cost =  [ 258872.77507019]
Epoch: 24
cost =  [ 258872.77507019]
Epoch: 25
cost =  [ 258872.77507019]
Epoch: 26
cost =  [ 258872.77507019]
Epoch: 27
cost =  [ 258872.77507019]
Epoch: 28
cost =  [ 2003326.79850769]
Epoch: 29
cost =  [ 258872.77507019]
Epoch: 30
cost =  [ 258872.77507019]
Epoch: 31
cost =  [ 258872.77507019]
Epoch: 32
cost =  [ 258872.77507019]
Epoch: 33
cost =  [ 258872.77507019]
Epoch: 34
cost =  [ 258872.77507019]
Epoch: 35
cost =  [ 258872.77507019]
Epoch: 36
cost =  [ 2003326.79850769]
Epoch: 37
cost =  [ 258872.77507019]
Epoch: 38
cost =  [ 258872.77507019]
Epoch: 39
cost =  [ 258872.77507019]
Epoch: 40
cost =  [ 258872.77507019]
Epoch: 41
cost =  [ 258872.77507019]
Epoch: 42
cost =  [ 258872.77507019]
Epoch: 43
cost =  [ 2003326.79850769]
Epoch: 44
cost =  [ 258872.77507019]
Epoch: 45
cost =  [ 258872.77507019]
Epoch: 46
cost =  [ 258872.77507019]
Epoch: 47
cost =  [ 258872.77507019]
Epoch: 48
cost =  [ 258872.77507019]
Epoch: 49
cost =  [ 258872.77507019]
Epoch: 50
cost =  [ 258872.77507019]
Epoch: 51
cost =  [ 2003326.79850769]
Epoch: 52
cost =  [ 258872.77507019]
Epoch: 53
cost =  [ 258872.77507019]
Epoch: 54
cost =  [ 258872.77507019]
Epoch: 55
cost =  [ 258872.77507019]
Epoch: 56
cost =  [ 258872.77507019]
Epoch: 57
cost =  [ 258872.77507019]
Epoch: 58
cost =  [ 258872.77507019]
Epoch: 59
cost =  [ 2003326.79850769]
Epoch: 60
cost =  [ 258872.77507019]
Epoch: 61
cost =  [ 258872.77507019]
Epoch: 62
cost =  [ 258872.77507019]
Epoch: 63
cost =  [ 258872.77507019]
Epoch: 64
cost =  [ 258872.77507019]
Epoch: 65
cost =  [ 258872.77507019]
Epoch: 66
cost =  [ 258872.77507019]
Epoch: 67
cost =  [ 2003326.79850769]
Epoch: 68
cost =  [ 258872.77507019]
Epoch: 69
cost =  [ 258872.77507019]
Epoch: 70
cost =  [ 258872.77507019]
Epoch: 71
cost =  [ 258872.77507019]
Epoch: 72
cost =  [ 258872.77507019]
Epoch: 73
cost =  [ 258872.77507019]
Epoch: 74
cost =  [ 2003326.79850769]
Epoch: 75
cost =  [ 258872.77507019]
Epoch: 76
cost =  [ 258872.77507019]
Epoch: 77
cost =  [ 258872.77507019]
Epoch: 78
cost =  [ 258872.77507019]
Epoch: 79
cost =  [ 258872.77507019]
Epoch: 80
cost =  [ 258872.77507019]
Epoch: 81
cost =  [ 258872.77507019]
Epoch: 82
cost =  [ 2003326.79850769]
Epoch: 83
cost =  [ 258872.77507019]
Epoch: 84
cost =  [ 258872.77507019]
Epoch: 85
cost =  [ 258872.77507019]
Epoch: 86
cost =  [ 258872.77507019]
Epoch: 87
cost =  [ 258872.77507019]
Epoch: 88
cost =  [ 258872.77507019]
Epoch: 89
cost =  [ 258872.77507019]
Epoch: 90
cost =  [ 2003326.79850769]
Epoch: 91
cost =  [ 258872.77507019]
Epoch: 92
cost =  [ 258872.77507019]
Epoch: 93
cost =  [ 258872.77507019]
Epoch: 94
cost =  [ 258872.77507019]
Epoch: 95
cost =  [ 258872.77507019]
Epoch: 96
cost =  [ 258872.77507019]
Epoch: 97
cost =  [ 258872.77507019]
Epoch: 98
cost =  [ 2003326.79850769]
Epoch: 99
cost =  [ 258872.77507019]
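One plausible reading of this log is that the updates are far too large: nothing is averaged over the batch, so the weight step grows with the number of examples, and the sigmoid units likely saturate, which could explain why the cost repeats exactly. A second issue is that numpy.random.rand(...) is evaluated once, while the graph is built, so the same noise array becomes a constant in the compiled function and is reused on every call. The NumPy sketch below is an illustration under those assumptions, not a verified fix for this exact setup: it draws fresh Bernoulli samples on every step, averages over the batch, and keeps one bias entry per unit. `cd1_step` is a hypothetical helper. In Theano, fresh samples are usually drawn with a `RandomStreams` object from `theano.tensor.shared_randomstreams`, via its `binomial(size=..., n=1, p=...)` method.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def cd1_step(x, w, bh, bv, lr, rng):
    """One CD-1 update on a mini-batch x of shape (n, n_visible).

    Returns the updated (w, bh, bv) and the mean squared reconstruction
    error per example, which is only a rough progress monitor.
    """
    n = x.shape[0]
    # Positive phase: hidden probabilities, then a fresh binary sample.
    h_prob = sigmoid(x @ w + bh)
    h_sample = (rng.random(h_prob.shape) < h_prob).astype(x.dtype)
    # Negative phase: reconstruct visible probabilities, then hiddens again.
    v_prob = sigmoid(h_sample @ w.T + bv)
    h2_prob = sigmoid(v_prob @ w + bh)
    # Average over the batch so the step size does not grow with n.
    dw = (x.T @ h_prob - v_prob.T @ h2_prob) / n
    dbh = (h_prob - h2_prob).mean(axis=0)  # shape (n_hidden,)
    dbv = (x - v_prob).mean(axis=0)        # shape (n_visible,)
    err = float(np.mean(np.sum((x - v_prob) ** 2, axis=1)))
    return w + lr * dw, bh + lr * dbh, bv + lr * dbv, err


# Usage on random binary data (a stand-in, not MNIST).
rng = np.random.default_rng(0)
x = (rng.random((100, 20)) < 0.3).astype(float)
w = 0.01 * rng.standard_normal((20, 8))
bh, bv = np.zeros(8), np.zeros(20)
for epoch in range(5):
    w, bh, bv, err = cd1_step(x, w, bh, bv, 0.1, rng)
    print(epoch, err)
```

In practice x would typically be a mini-batch (for example 100 rows) drawn from the training set rather than all training images at once; the reconstruction error is a rough indicator, not the quantity CD-1 actually optimizes.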