我写了一个 SRGAN 实现。在 Python 程序的入口点类中,我声明了一个使用 VGG19 模型返回均方的函数:
# <!--- COST FUNCTION --->
def build_vgg19_loss_network(ground_truth_image, predicted_image):
loss_model = Vgg19Loss.define_loss_model(high_resolution_shape)
return mean(square(loss_model(ground_truth_image) - loss_model(predicted_image)))
import keras.losses
keras.losses.build_vgg19_loss_network = build_vgg19_loss_network
# <!--- /COST FUNCTION --->
(Vgg19Loss下面进一步显示的类)
如您所见,我在 import 中添加了这个自定义损失函数keras.losses。为什么?因为我认为它可以解决以下问题......:当我执行命令tflite_convert --output_file=srgan.tflite --keras_model_file=srgan.h5时,Python解释器会引发此错误:
raise ValueError('Unknown ' + printable_module_name + ':' + object_name) ValueError: Unknown loss function:build_vgg19_loss_network
然而,它并没有解决问题。还有其他可行的解决方案吗?
这是Vgg19Loss课程:
from keras import Model
from keras.applications import VGG19
class Vgg19Loss:
def __init__(self):
pass
@staticmethod
def define_loss_model(high_resolution_shape):
model_vgg19 = VGG19(False, 'imagenet', input_shape=high_resolution_shape)
model_vgg19.trainable = False
for l in model_vgg19.layers:
l.trainable = False
loss_model = Model(model_vgg19.input, model_vgg19.get_layer('block5_conv4').output)
loss_model.trainable = False
return loss_model