How to implement my own prototype-learning loss function with a Keras model

data-mining python keras tensorflow loss-function meta-learning
2022-02-28 16:10:44

I am trying to migrate this code, "Omniglot character set classification using prototypical networks", to TensorFlow 2.1.0 and Keras 2.3.1.

My question is about how the Euclidean distance between the training data and the validation data is used. Look at this code:

def convolution_block(inputs, out_channels, name='conv'):

    conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
    conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
    conv = tf.nn.relu(conv)
    conv = tf.contrib.layers.max_pool2d(conv, 2)

    return conv


def get_embeddings(support_set, h_dim, z_dim, reuse=False):

    net = convolution_block(support_set, h_dim)
    net = convolution_block(net, h_dim)
    net = convolution_block(net, h_dim) 
    net = convolution_block(net, z_dim) 
    net = tf.contrib.layers.flatten(net)

    return net



support_set_embeddings = get_embeddings(tf.reshape(support_set, [num_classes * num_support_points, img_height, img_width, channels]), h_dim, z_dim)

embedding_dimension = tf.shape(support_set_embeddings)[-1]

class_prototype = tf.reduce_mean(tf.reshape(support_set_embeddings, [num_classes, num_support_points, embedding_dimension]), axis=1)

query_set_embeddings = get_embeddings(tf.reshape(query_set, [num_classes * num_query_points, img_height, img_width, channels]), h_dim, z_dim, reuse=True)


def euclidean_distance(a, b):

    N, D = tf.shape(a)[0], tf.shape(a)[1]
    M = tf.shape(b)[0]
    a = tf.tile(tf.expand_dims(a, axis=1), (1, M, 1))
    b = tf.tile(tf.expand_dims(b, axis=0), (N, 1, 1))
    return tf.reduce_mean(tf.square(a - b), axis=2)


distance = euclidean_distance(query_set_embeddings,class_prototype)

predicted_probability = tf.reshape(tf.nn.log_softmax(-distance), [num_classes, num_query_points, -1])

loss = -tf.reduce_mean(tf.reshape(tf.reduce_sum(tf.multiply(y_one_hot, predicted_probability), axis=-1), [-1]))
accuracy = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(predicted_probability, axis=-1), y)))

train = tf.train.AdamOptimizer().minimize(loss)
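For context, the tf.contrib pieces no longer exist in TensorFlow 2.x, so a rough tf.keras rewrite of the two embedding helpers (a sketch of my own, not the original author's code; build_embedding_model is a name I made up) might look like this:

import tensorflow as tf
from tensorflow.keras import layers

def convolution_block(inputs, out_channels):
    # Conv -> BatchNorm -> ReLU -> MaxPool, mirroring the tf.contrib version above
    x = layers.Conv2D(out_channels, kernel_size=3, padding='same')(inputs)
    x = layers.BatchNormalization(momentum=0.99)(x)
    x = layers.ReLU()(x)
    x = layers.MaxPooling2D(pool_size=2)(x)
    return x

def build_embedding_model(input_shape, h_dim, z_dim):
    # Functional-API equivalent of get_embeddings: four blocks plus a flatten
    inputs = layers.Input(shape=input_shape)
    net = convolution_block(inputs, h_dim)
    net = convolution_block(net, h_dim)
    net = convolution_block(net, h_dim)
    net = convolution_block(net, z_dim)
    net = layers.Flatten()(net)
    return tf.keras.Model(inputs, net)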

If I understand everything correctly, it gets the embeddings of the support_set (a.k.a. the training data) and the embeddings of the query_set (a.k.a. the validation data). For each class it computes the mean of that class's support_set embeddings, since they all come from the same class, and then uses this mean (a.k.a. the class_prototype) to compute the distance between each query_set embedding and each class prototype.
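Written as plain TF 2.x operations, that prototype / distance / loss computation could look roughly like this (a sketch that follows the code above; the function names are mine and the shapes are assumed to match the reshapes in the original):

import tensorflow as tf

def compute_prototypes(support_embeddings, num_classes, num_support_points):
    # Mean embedding of each class -> [num_classes, embedding_dim]
    per_class = tf.reshape(support_embeddings, [num_classes, num_support_points, -1])
    return tf.reduce_mean(per_class, axis=1)

def euclidean_distance(a, b):
    # a: [N, D] query embeddings, b: [M, D] class prototypes -> [N, M]
    a = tf.expand_dims(a, axis=1)   # [N, 1, D]
    b = tf.expand_dims(b, axis=0)   # [1, M, D]
    return tf.reduce_mean(tf.square(a - b), axis=2)

def prototypical_loss(query_embeddings, prototypes, y_one_hot):
    # Softmax over negative distances, then cross-entropy, as in the original code
    log_p = tf.nn.log_softmax(-euclidean_distance(query_embeddings, prototypes), axis=-1)
    return -tf.reduce_mean(tf.reduce_sum(y_one_hot * log_p, axis=-1))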

So, what if I want to use VGG16 as the get_embeddings function? In other words, I would use it to get the embeddings of both the support_set and the query_set:

from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D
from tensorflow.keras.models import Model

def vgg16_feature_extractor(input_size = (200,200,1)):
    inputs = Input(input_size, name = 'input')

    conv1 = Conv2D(64, (3, 3), activation = 'relu', padding = 'same', name ='conv1_1')(inputs)
    conv1 = Conv2D(64, (3, 3), activation = 'relu', padding = 'same', name ='conv1_2')(conv1)
    pool1 = MaxPooling2D(pool_size = (2,2), strides = (2,2), name = 'pool_1')(conv1)

    conv2 = Conv2D(128, (3, 3), activation = 'relu', padding = 'same', name ='conv2_1')(pool1)
    conv2 = Conv2D(128, (3, 3), activation = 'relu', padding = 'same', name ='conv2_2')(conv2)
    pool2 = MaxPooling2D(pool_size = (2,2), strides = (2,2), name = 'pool_2')(conv2)

    conv3 = Conv2D(256, (3, 3), activation = 'relu', padding = 'same', name ='conv3_1')(pool2)
    conv3 = Conv2D(256, (3, 3), activation = 'relu', padding = 'same', name ='conv3_2')(conv3)
    conv3 = Conv2D(256, (3, 3), activation = 'relu', padding = 'same', name ='conv3_3')(conv3)
    pool3 = MaxPooling2D(pool_size = (2,2), strides = (2,2), name = 'pool_3')(conv3)

    conv4 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', name ='conv4_1')(pool3)
    conv4 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', name ='conv4_2')(conv4)
    conv4 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', name ='conv4_3')(conv4)
    pool4 = MaxPooling2D(pool_size = (2,2), strides = (2,2), name = 'pool_4')(conv4)

    conv5 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', name ='conv5_1')(pool4)
    conv5 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', name ='conv5_2')(conv5)
    conv5 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', name ='conv5_3')(conv5)
    pool5 = MaxPooling2D(pool_size = (2,2), strides = (2,2), name = 'pool_5')(conv5)

    model = Model(inputs = inputs, outputs = pool5, name = 'vgg-16_feature_extractor')

    return model
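Note that pool5 is still a 4-D feature map (for a 200x200x1 input its shape is (batch, 6, 6, 512)), while the prototype computation above works on flat vectors. A hypothetical wrapper (my own naming) that flattens the extractor output could be:

from tensorflow.keras.layers import Flatten
from tensorflow.keras.models import Model

def vgg16_embedding_model(input_size = (200,200,1)):
    # reuse the feature extractor defined above and flatten its output to an embedding vector
    backbone = vgg16_feature_extractor(input_size)
    flat = Flatten(name = 'embedding')(backbone.output)
    return Model(inputs = backbone.input, outputs = flat, name = 'vgg-16_embedding')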

Then, in train.py:

model = vgg16_feature_extractor(input_size = (200,200,1))
model.compile(optimizer=opt, loss=my_own_loss_function, metrics=['accuracy'])
model.fit(...)

I don't know how to implement my_own_loss_function, because this function only receives two arguments, y_true and y_pred, and the loss has to be computed using the Euclidean distance between the support_set embeddings and the query_set embeddings.

How should I implement my_own_loss_function to use it the way I need?

Maybe y_true should be the embeddings of the support_set, and y_pred the embeddings of the query_set?

1 Answer

Well, there are several ways to do this.

A fairly robust solution is to define a prediction (pred) layer:

from tensorflow.keras.layers import Layer
from tensorflow.keras import backend as K

class PredLayer(Layer):
    """
        Layer object to calculate the distance between query embeddings and support embeddings.
    """
    def __init__(self, **kwargs):
        super(PredLayer, self).__init__(**kwargs)

    def euclidean_distance(self, inputs):
        """
            Euclidean square distance.
        """
        support, query = inputs
        output = K.mean(K.square(support - query), axis=-1)
        output = K.expand_dims(output, 1)
        return output

    def call(self, inputs):
        y_pred = self.euclidean_distance(inputs)
        return y_pred

So you have to build your Keras network in such a way that your support and query embeddings become the inputs of this layer.

...

model_pred = Model(inputs = inputs, outputs = predlayer, name = 'vgg-16_feature_extractor_pred')

return model, model_pred
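For example, sharing the feature extractor from the question between a support input and a query input, and feeding both embeddings into PredLayer, could look roughly like this (a sketch; the names and shapes are only illustrative):

from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.models import Model

support_input = Input((200, 200, 1), name = 'support')
query_input = Input((200, 200, 1), name = 'query')

# one backbone, applied to both inputs, so the weights are shared
backbone = vgg16_feature_extractor(input_size = (200, 200, 1))
flatten = Flatten(name = 'embedding')

support_embedding = flatten(backbone(support_input))
query_embedding = flatten(backbone(query_input))

# PredLayer outputs the (squared) distance between the two embeddings
distance = PredLayer(name = 'distance')([support_embedding, query_embedding])

model_pred = Model(inputs = [support_input, query_input], outputs = distance,
                   name = 'vgg-16_feature_extractor_pred')

# because the distance is now the model output, the custom loss only needs the
# standard (y_true, y_pred) signature, e.g. a loss that pulls matching pairs together:
# model_pred.compile(optimizer = 'adam', loss = lambda y_true, y_pred: K.mean(y_pred))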