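I'm trying to migrate this code, "Omniglot Character Set Classification Using Prototypical Network", to TensorFlow 2.1.0 and Keras 2.3.1.
My question is about how to use the Euclidean distance between the training data and the validation data. Take a look at this code: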
def convolution_block(inputs, out_channels, name='conv'):
    conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')
    conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)
    conv = tf.nn.relu(conv)
    conv = tf.contrib.layers.max_pool2d(conv, 2)
    return conv

def get_embeddings(support_set, h_dim, z_dim, reuse=False):
    net = convolution_block(support_set, h_dim)
    net = convolution_block(net, h_dim)
    net = convolution_block(net, h_dim)
    net = convolution_block(net, z_dim)
    net = tf.contrib.layers.flatten(net)
    return net

support_set_embeddings = get_embeddings(tf.reshape(support_set, [num_classes * num_support_points, img_height, img_width, channels]), h_dim, z_dim)
embedding_dimension = tf.shape(support_set_embeddings)[-1]
class_prototype = tf.reduce_mean(tf.reshape(support_set_embeddings, [num_classes, num_support_points, embedding_dimension]), axis=1)
query_set_embeddings = get_embeddings(tf.reshape(query_set, [num_classes * num_query_points, img_height, img_width, channels]), h_dim, z_dim, reuse=True)

def euclidean_distance(a, b):
    N, D = tf.shape(a)[0], tf.shape(a)[1]
    M = tf.shape(b)[0]
    a = tf.tile(tf.expand_dims(a, axis=1), (1, M, 1))
    b = tf.tile(tf.expand_dims(b, axis=0), (N, 1, 1))
    return tf.reduce_mean(tf.square(a - b), axis=2)

distance = euclidean_distance(query_set_embeddings, class_prototype)
predicted_probability = tf.reshape(tf.nn.log_softmax(-distance), [num_classes, num_query_points, -1])
loss = -tf.reduce_mean(tf.reshape(tf.reduce_sum(tf.multiply(y_one_hot, predicted_probability), axis=-1), [-1]))
accuracy = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(predicted_probability, axis=-1), y)))
train = tf.train.AdamOptimizer().minimize(loss)
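If I have understood everything correctly, it gets the embeddings from support_set (aka the training data) and the embeddings from query_set (aka the validation data). For each class, it computes the mean of that class's embeddings in support_set, because they all come from the same class. Then it uses this mean (aka class_prototype) to compute the distance between each query_set embedding and the mean.
So, suppose I want to use VGG16 as the get_embeddings function. In other words, I will use it to get the embeddings of both support_set and query_set: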
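To check that reading of the code, here is a minimal NumPy sketch of the same two steps on a toy episode. The helper names class_prototypes and pairwise_distance and the toy arrays are made up for illustration, and the sketch assumes the support rows are grouped by class, as the reshape in the original code implies. Like the original euclidean_distance, it returns the mean of the squared differences, not a square-rooted distance.

```python
import numpy as np


def class_prototypes(support_embeddings, num_classes, num_support_points):
    """Average the support embeddings of each class into one prototype per class."""
    # Assumes rows are grouped by class: all points of class 0 first, then class 1, ...
    dim = support_embeddings.shape[-1]
    grouped = support_embeddings.reshape(num_classes, num_support_points, dim)
    return grouped.mean(axis=1)  # shape: (num_classes, dim)


def pairwise_distance(a, b):
    """Mean squared difference between every row of a (N, D) and every row of b (M, D)."""
    # Broadcasting replaces the explicit tf.tile calls: (N, 1, D) - (1, M, D) -> (N, M, D).
    diff = a[:, None, :] - b[None, :, :]
    return np.mean(np.square(diff), axis=2)  # shape: (N, M)


# Toy episode: 2 classes, 2 support points each, 3-dimensional embeddings.
support = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0],
                    [10.0, 10.0, 10.0], [12.0, 10.0, 10.0]])
query = np.array([[1.0, 0.0, 0.0], [11.0, 10.0, 10.0]])

prototypes = class_prototypes(support, num_classes=2, num_support_points=2)
distances = pairwise_distance(query, prototypes)
predicted_class = distances.argmin(axis=1)  # nearest prototype per query point
```

Each query point is assigned to the class whose prototype is closest, which is what the argmax over log_softmax(-distance) does in the original code.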
def vgg16_feature_extractor(input_size = (200,200,1)):
    inputs = Input(input_size, name = 'input')
    conv1 = Conv2D(64, (3, 3), activation = 'relu', padding = 'same', name ='conv1_1')(inputs)
    conv1 = Conv2D(64, (3, 3), activation = 'relu', padding = 'same', name ='conv1_2')(conv1)
    pool1 = MaxPooling2D(pool_size = (2,2), strides = (2,2), name = 'pool_1')(conv1)
    conv2 = Conv2D(128, (3, 3), activation = 'relu', padding = 'same', name ='conv2_1')(pool1)
    conv2 = Conv2D(128, (3, 3), activation = 'relu', padding = 'same', name ='conv2_2')(conv2)
    pool2 = MaxPooling2D(pool_size = (2,2), strides = (2,2), name = 'pool_2')(conv2)
    conv3 = Conv2D(256, (3, 3), activation = 'relu', padding = 'same', name ='conv3_1')(pool2)
    conv3 = Conv2D(256, (3, 3), activation = 'relu', padding = 'same', name ='conv3_2')(conv3)
    conv3 = Conv2D(256, (3, 3), activation = 'relu', padding = 'same', name ='conv3_3')(conv3)
    pool3 = MaxPooling2D(pool_size = (2,2), strides = (2,2), name = 'pool_3')(conv3)
    conv4 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', name ='conv4_1')(pool3)
    conv4 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', name ='conv4_2')(conv4)
    conv4 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', name ='conv4_3')(conv4)
    pool4 = MaxPooling2D(pool_size = (2,2), strides = (2,2), name = 'pool_4')(conv4)
    conv5 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', name ='conv5_1')(pool4)
    conv5 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', name ='conv5_2')(conv5)
    conv5 = Conv2D(512, (3, 3), activation = 'relu', padding = 'same', name ='conv5_3')(conv5)
    pool5 = MaxPooling2D(pool_size = (2,2), strides = (2,2), name = 'pool_5')(conv5)
    model = Model(inputs = inputs, outputs = pool5, name = 'vgg-16_feature_extractor')
    return model
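One detail worth checking: the original get_embeddings ends with a flatten, while this model ends at pool_5, so its output is a 4-D feature map, not one vector per image. The small sketch below works out that output size for a 200x200 input. The function name vgg16_output_shape is made up for illustration, and it assumes the default 'valid' padding of MaxPooling2D, as in the model above.

```python
def vgg16_output_shape(height, width, num_pools=5, channels=512):
    """Output size after num_pools 2x2 max-pools with stride 2 and 'valid' padding."""
    for _ in range(num_pools):
        # 'same'-padded 3x3 convolutions keep the size; each pool floors the halved size.
        height, width = height // 2, width // 2
    return height, width, channels


out_h, out_w, out_c = vgg16_output_shape(200, 200)
# Length of the embedding vector if the feature map were flattened.
embedding_dimension = out_h * out_w * out_c
print((out_h, out_w, out_c), embedding_dimension)
```

So each image would yield a 6x6x512 map (200 -> 100 -> 50 -> 25 -> 12 -> 6), and a Flatten layer after pool_5 (an addition that is not in the model above) would presumably be needed to get 18432-dimensional embedding vectors that can be averaged and compared.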
Then, in train.py:
model = vgg16_feature_extractor(input_size = (200,200,1))
model.compile(optimizer=opt, loss=my_own_loss_function, metrics=['accuracy'])
model.fit(...)
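I don't know how to implement my_own_loss_function, because a Keras loss function only receives two arguments, y_true and y_pred, and this one has to use the Euclidean distance between the embeddings from support_set and the embeddings from query_set.
How do I have to implement my_own_loss_function so that I can use it the way I need?
Maybe y_true is the embeddings from support_set and y_pred is the embeddings from query_set.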
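To make the target concrete, here is a minimal NumPy sketch of the quantity the loss would have to return for one episode, following the loss and accuracy lines of the original code. It is not a Keras loss: it takes both embedding sets explicitly, which is exactly what the two-argument (y_true, y_pred) signature does not offer. The names log_softmax and prototypical_loss and the toy data are made up for illustration. The sketch uses integer query labels where the original uses y_one_hot, and assumes the support rows are grouped by class.

```python
import numpy as np


def log_softmax(x):
    """Numerically stable log-softmax over the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def prototypical_loss(support_emb, query_emb, query_labels, num_classes, num_support_points):
    """Cross-entropy of the query points over negative distances to the class prototypes."""
    dim = support_emb.shape[-1]
    # Assumes support rows are grouped by class (class 0 first, then class 1, ...).
    prototypes = support_emb.reshape(num_classes, num_support_points, dim).mean(axis=1)
    # Same distance as euclidean_distance in the original code: mean of squared differences.
    distance = np.mean(np.square(query_emb[:, None, :] - prototypes[None, :, :]), axis=2)
    log_prob = log_softmax(-distance)  # shape: (num_queries, num_classes)
    # Log-probability assigned to the true class of each query point.
    picked = log_prob[np.arange(len(query_labels)), query_labels]
    loss = -picked.mean()
    accuracy = (log_prob.argmax(axis=-1) == query_labels).mean()
    return loss, accuracy


# Toy episode: 2 classes, 2 support points each, 2-dimensional embeddings.
support_emb = np.array([[0.0, 0.0], [0.0, 2.0], [4.0, 0.0], [4.0, 2.0]])
query_emb = np.array([[0.0, 1.0], [4.0, 1.0]])
query_labels = np.array([0, 1])

loss, accuracy = prototypical_loss(support_emb, query_emb, query_labels,
                                   num_classes=2, num_support_points=2)
```

Because the loss needs both embedding sets at once, one possible direction (an assumption, not a confirmed solution) is to write the same arithmetic with TensorFlow ops inside a custom training loop based on tf.GradientTape, instead of passing a function to model.compile.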