在玻尔兹曼机中学习权重

机器算法验证 神经网络
2022-03-18 22:30:36

我试图了解玻尔兹曼机器的工作原理,但我不太确定权重是如何学习的,也无法找到清晰的描述。以下是正确的吗?(此外,任何好的玻尔兹曼机器解释的指针也很好。)

我们有一组可见单元(例如,对应于图像中的黑色/白色像素)和一组隐藏单元。权重以某种方式初始化(​​例如,从 [-0.5, 0.5] 统一),然后我们在以下两个阶段之间交替,直到达到某个停止规则:

  1. 钳位阶段——在这个阶段,所有可见单元的值都是固定的,所以我们只更新隐藏单元的状态(根据玻尔兹曼随机激活规则)。我们更新直到网络达到平衡。一旦达到平衡,我们将继续更新N更多次(对于一些预定义的N),跟踪平均值xixj(在哪里xi,xj是节点的状态ij)。在那些之后N均衡更新,我们更新wij=wij+1CAverage(xixj), 在哪里C是一些学习率。(或者,不是在最后进行批量更新,而是在平衡步骤之后更新?)

  2. 自由阶段——在这个阶段,所有单元的状态都会更新。一旦达到平衡,我们同样会继续更新 N' 次,但不是在最后添加相关性,而是减去:wij=wij1CAverage(xixj).

所以我的主要问题是:

  1. 每当我们处于钳制阶段时,我们是将可见单元重置为我们想要学习的模式之一(以某种频率表示该模式的重要性),还是让可见单元保持它们所处的状态在自由阶段结束时?

  2. 我们是在每个阶段结束时对权重进行批量更新,还是在阶段的每个平衡步骤更新权重?(或者,任何一个都可以吗?)

3个回答

直观地说,您可以将可见单元视为“模型所见”,将隐藏单元视为“模型的心理状态”。当您将所有可见单位设置为某些值时,您“将数据显示给模型”。然后,当你激活隐藏单元时,模型会根据看到的情况调整其心理状态。

接下来,您让模型自由发挥并进行幻想。它会变得封闭起来,从字面上看到它的思维产生的一些东西,并根据这些图像产生新的思维状态。

我们通过调整权重(和偏差)所做的就是让模型更多地相信数据,而不是相信它自己的幻想。这样,经过一些训练后,它会相信一些(希望如此)非常好的数据模型,例如,我们可以问“你相信这对 (X,Y) 吗?你找到它的可能性有多大?先生,你有什么看法?”玻尔兹曼机?”

最后这里是对基于能量的模型的简要描述,它应该让您直观地了解 Clamped 和 Free 阶段的来源以及我们希望如何运行它们。

http://deeplearning.net/tutorial/rbm.html#energy-based-models-ebm

有趣的是,直观清晰的更新规则来自模型生成数据的对数似然推导。

考虑到这些直觉,现在可以更轻松地回答您的问题:

  1. 我们必须将可见单位重置为我们希望模型相信的一些数据。如果我们使用自由阶段结束时的值,它只会继续幻想,最终强制执行它自己被误导的信念。

  2. 最好在阶段结束后进行更新。特别是如果是钳制阶段,最好给模型一些时间来“专注”于数据。较早的更新会减慢收敛速度,因为它们会在模型尚未将其心态调整为现实时强制连接。在幻想的每个平衡步骤之后更新重量应该会减少伤害,尽管我没有这方面的经验。

如果您想提高您对 EBM、BM 和 RBM 的直觉,我建议您观看 Geoffrey Hinton 的一些关于该主题的讲座,他有一些很好的类比。

  1. 是的,“我们将可见单元重置(钳制)为我们想要学习的模式之一(以某种频率表示该模式的重要性)。”

  2. 是的,“我们在每个阶段结束时对权重进行批量更新。” 我不认为更新“阶段中每个平衡步骤的权重”会导致快速收敛,因为网络会被瞬时错误“分散注意力”——我已经以这种方式实现了玻尔兹曼机,我记得它运行得不是很好直到我将其更改为批量更新。

这是基于 Paul Ivanov 的代码的玻尔兹曼机器示例 Python 代码

http://redwood.berkeley.edu/wiki/VS265:_Homework_assignments

import numpy as np

def extract_patches(im,SZ,n):
    imsize,imsize=im.shape;
    X=np.zeros((n,SZ**2),dtype=np.int8);
    startsx= np.random.randint(imsize-SZ,size=n)
    startsy=np.random.randint(imsize-SZ,size=n)
    for i,stx,sty in zip(xrange(n), startsx,startsy):
        P=im[sty:sty+SZ, stx:stx+SZ];
        X[i]=2*P.flat[:]-1;
    return X.T

def sample(T,b,n,num_init_samples):
    """
    sample.m - sample states from model distribution

    function S = sample(T,b,n, num_init_samples)

    T:                weight matrix
    b:                bias
    n:                number of samples
    num_init_samples: number of initial Gibbs sweeps
    """
    N=T.shape[0]

    # initialize state vector for sampling
    s=2*(np.random.rand(N)<sigmoid(b))-1

    for k in xrange(num_init_samples):
        s=draw(s,T,b)

    # sample states
    S=np.zeros((N,n))
    S[:,0]=s
    for i in xrange(1,n):
        S[:,i]=draw(S[:,i-1],T,b)

    return S

def sigmoid(u):
    """
    sigmoid.m - sigmoid function

    function s = sigmoid(u)
    """
    return 1./(1.+np.exp(-u));

def draw(Sin,T,b):
    """
    draw.m - perform single Gibbs sweep to draw a sample from distribution

    function S = draw(Sin,T,b)

    Sin:      initial state
    T:        weight matrix
    b:        bias
    """
    N=Sin.shape[0]
    S=Sin.copy()
    rand = np.random.rand(N,1)
    for i in xrange(N):
        h=np.dot(T[i,:],S)+b[i];
        S[i]=2*(rand[i]<sigmoid(h))-1;

    return S

def run(im, T=None, b=None, display=True,N=4,num_trials=100,batch_size=100,num_init_samples=10,eta=0.1):
    SZ=np.sqrt(N);
    if T is None: T=np.zeros((N,N)); # weight matrix
    if b is None: b=np.zeros(N); # bias

    for t in xrange(num_trials):
        print t, num_trials
        # data statistics (clamped)
        X=extract_patches(im,SZ,batch_size).astype(np.float);
        R_data=np.dot(X,X.T)/batch_size;
        mu_data=X.mean(1);

        # prior statistics (unclamped)
        S=sample(T,b,batch_size,num_init_samples);
        R_prior=np.dot(S,S.T)/batch_size;
        mu_prior=S.mean(1);

        # update params
        deltaT=eta*(R_data - R_prior);
        T=T+deltaT;

        deltab=eta*(mu_data - mu_prior);
        b=b+deltab;


    return T, b

if __name__ == "__main__": 
    A = np.array([\
    [0.,1.,1.,0],
    [1.,1.,0, 0],
    [1.,1.,1.,0],
    [0, 1.,1.,1.],
    [0, 0, 1.,0]
    ])
    T,b = run(A,display=False)
    print T
    print b

它通过创建数据补丁来工作,但是可以对其进行修改,以便代码始终适用于所有数据。