tflite_convert 具有自定义损失函数的 Keras h5 模型会导致 ValueError,即使我将其添加到 Keras 损失导入中

数据挖掘 Python 喀拉斯 张量流
2022-02-17 19:11:56

我写了一个 SRGAN 实现。在 Python 程序的入口点类中,我声明了一个使用 VGG19 模型返回均方的函数:

# <!--- COST FUNCTION --->
def build_vgg19_loss_network(ground_truth_image, predicted_image):
    loss_model = Vgg19Loss.define_loss_model(high_resolution_shape)
    return mean(square(loss_model(ground_truth_image) - loss_model(predicted_image)))

import keras.losses
keras.losses.build_vgg19_loss_network = build_vgg19_loss_network
# <!--- /COST FUNCTION --->

Vgg19Loss下面进一步显示的类)

如您所见,我在 import 中添加了这个自定义损失函数keras.losses为什么?因为我认为它可以解决以下问题......:当我执行命令tflite_convert --output_file=srgan.tflite --keras_model_file=srgan.h5时,Python解释器会引发此错误:

raise ValueError('Unknown ' + printable_module_name + ':' + object_name) ValueError: Unknown loss function:build_vgg19_loss_network

然而,它并没有解决问题。还有其他可行的解决方案吗?

这是Vgg19Loss课程:

from keras import Model
from keras.applications import VGG19


class Vgg19Loss:
    def __init__(self):
        pass

    @staticmethod
    def define_loss_model(high_resolution_shape):
        model_vgg19 = VGG19(False, 'imagenet', input_shape=high_resolution_shape)
        model_vgg19.trainable = False
        for l in model_vgg19.layers:
            l.trainable = False
        loss_model = Model(model_vgg19.input, model_vgg19.get_layer('block5_conv4').output)
        loss_model.trainable = False
        return loss_model
1个回答

我尝试了您通过以下方式发布的代码:

from keras import Model
from keras.applications import VGG19
import keras.backend as K


class Vgg19Loss:
    def __init__(self):
        pass

    @staticmethod
    def define_loss_model(high_resolution_shape):
        model_vgg19 = VGG19(False, 'imagenet', input_shape=high_resolution_shape)
        model_vgg19.trainable = False
        for l in model_vgg19.layers:
            l.trainable = False
        loss_model = Model(model_vgg19.input, model_vgg19.get_layer('block5_conv4').output)
        loss_model.trainable = False
        return loss_model


def build_vgg19_loss_network(ground_truth_image, predicted_image):
    loss_model = Vgg19Loss.define_loss_model(high_resolution_shape) # where is this variable coming from?
    return K.mean(K.square(loss_model(ground_truth_image) - loss_model(predicted_image)))

import keras.losses
keras.losses.build_vgg19_loss_network = build_vgg19_loss_network

print(keras.losses.build_vgg19_loss_network)  # <function build_vgg19_loss_network at 0x7f05e8e1cbf8>

我没有收到错误消息,并且该功能已分配给损失模块。这意味着有问题的行可能不是您发布的内容的一部分。很高兴知道哪一行代码引发了您引用的错误。

但是,我不确定high_resolution_shape函数中第 22 行的这个参数来自哪里build_vgg_19_network如果这是一个全局常量,它应该用下划线分隔的所有大写字母来写,以防止混淆。如果没有定义它迟早会抛出一个 NameError 。

如果我keras.losses.build_vgg19_loss_network(None, None)在运行上面的代码后执行,我会收到以下错误消息:

NameError:未定义名称“high_resolution_shape”

编辑:如果此错误仅在 TFLite 转换期间发生,则因为 tensorflow 1.x 中的 TFLiteConverter 尚不支持自定义对象。但是,在 tensorflow Github 存储库中有一个提交,它解决了这个问题并添加了对自定义对象的支持(另请参阅相关的拉取请求)。它应该是官方tensorflow v2.0.0-beta1的一部分。