从 Keras 模型中获取每个批次/时期的 NN 权重

数据挖掘 Python 喀拉斯 张量流 梯度下降
2021-09-14 18:07:00

我试图在训练后从 Keras 模型中获取每个批次/时期的权重。为此,我使用回调使模型在训练期间保存权重。然而,在模型训练之后,看起来我只从最后一个时期获得权重。如何获得模型生成的所有权重?这是一个简单的例子:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import layers

# Generate data

start, stop = 1,100
cnt = stop - start + 1
xs = np.linspace(start, stop, num = cnt)
b,k = 1,2
ys = np.array([k*x + b for x in xs])

# Simple model with one feature and one unit for regression task

model = keras.Sequential([
    layers.Dense(units=1, input_shape=[1], activation='relu')
])
model.compile(loss='mae', optimizer='adam')
batch_size = int(cnt / 5)
epochs = 80

接下来是回调,以某种频率保存 Keras 模型权重。根据 Keras 文档:

save_freq:“纪元”或整数。使用 'epoch' 时,回调应在每个 epoch 后保存模型。使用整数时,回调应在这么多批次结束时保存模型。

checkpoint_filepath = './checkpoint.hdf5'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    save_freq ='epoch', # 1 for every batch
    save_best_only=False
)

# Train model

history = model.fit(xs, ys, batch_size=batch_size, epochs=epochs, 
                    callbacks=[model_checkpoint_callback])

我使用两种不同的方法来获得权重。第一的:

w, b = model.weights
print("Weights: \n {} \n Bias: \n {}".format(w,b))

Weights: 
 <tf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[-0.1450262]], dtype=float32)> 
 Bias: 
 <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>

这导致一个权重和一个偏差,而不是模型在每个批次/时期生成的所有权重。

第二种直接从 h5 文件中获取权重的方法:

# Functions to read weights from h5 file
import h5py

def getH5Keys(fileName):

    keys = []
    with h5py.File(fileName, mode='r') as f:
        for key in f:
            keys.append(key)

    return keys

def isGroup(obj):
    if isinstance(obj, h5py.Group):
        return True
    else:
        return False

def isDataset(obj):
    if isinstance(obj, h5py.Dataset):
        return True
    else:
        return False

def getDataSetsFromGroup(datasets, obj):
    if isGroup(obj):
        for key in obj:
            x = obj[key]
            getDataSetsFromGroup(datasets, x)
    else:
        datasets.append(obj)


def getWeightsForLayer(layerName, fileName):

    weights = []
    with h5py.File(fileName, mode='r') as f:
        for key in f:
            if layerName in key:
                obj = f[key]
                datasets = []
                getDataSetsFromGroup(datasets, obj)

                for dataset in datasets:
                    w = np.array(dataset)
                    weights.append(w)

    return weights

此方法为一个权重和一个偏差返回相同的奇异值:

layers = getH5Keys(checkpoint_filepath)
firstLayer = layers[0]
print(layers) # ['dense']

weights = getWeightsForLayer(firstLayer, checkpoint_filepath)
for w in weights:
    print(w.shape)
print(weights)

输出:

(1,)
(1, 1)
[array([0.], dtype=float32), array([[-0.1450262]], dtype=float32)]

同样,我只得到一个权重和一个偏差。如何获取模型为每个批次/时期生成的所有权重?

更新

10xAI为我工作的答案。但是,在我的情况下,我有一个具有一个单元的网络级别,因此我以不同的方式访问权重和偏差:

weights_dict = {}
weight_callback = tf.keras.callbacks.LambdaCallback \
( on_epoch_end=lambda epoch, logs:  weights_dict.update({epoch:model.get_weights()}))

# Train model
history = model.fit(xs, ys, batch_size=batch_size, epochs=epochs, 
                    callbacks=[weight_callback])

print(weights_dict[0])
Output: [array([[1.5375139]], dtype=float32), array([0.00499998], dtype=float32)]

print("*** Epoch: ", epoch, "\nWeight: ", weights_dict[0][0][0], " bias: ", weights_dict[1][0])
Output: *** Epoch:  79 
Weight:  [1.5375139]  bias:  [[1.5424858]]
2个回答

您可以使用 lambda 回调并将其保存在字典中。

weights_dict = {}

weight_callback = tf.keras.callbacks.LambdaCallback \
( on_epoch_end=lambda epoch, logs: weights_dict.update({epoch:model.get_weights()}))

history = model.fit( x_train, y_train, batch_size=16, epochs=5, callbacks=weight_callback )

# retrive weights
for epoch,weights in weights_dict.items():
    print("Weights for 2nd Layer of epoch #",epoch+1)
    print(weights[2])
    print("Bias for 2nd Layer of epoch #",epoch+1)
    print(weights[3])

在此处输入图像描述

您也可以为批次级别创建它。

Keras 文档表明文件名可以在文件名中包含纪元号要允许保留每组数据,请考虑使用类似以下的文件名:

checkpoint_filepath = './checkpoint-{epoch:02d}.hdf5'

filepath:字符串或PathLike,保存模型文件的路径。

  • filepath 可以包含命名的格式化选项,它将填充日志中的纪元和键的值(传入on_epoch_end)。例如:如果文件路径是weights.{epoch:02d}-{val_loss:.2f}.hdf5,那么模型检查点将与文件名中的纪元号和验证损失一起保存。