如何将 Keras h5 转换为 PyTorch pth 格式?

数据挖掘 深度学习 喀拉斯 火炬
2022-02-11 03:16:16

我在 Keras 中训练了一个特征提取器,并将权重保存为 h5 文件。现在我想将相同的权重加载到在 PyTorch 中创建和初始化的相同模型中,以进行性能比较。有什么方法可以将 h5 文件转换为 pth 文件,以便将其加载到 PyTorch 模型中?

1个回答

您必须小心 TF 和 PyTorch 的版本(因为某些命令可能不同)。

基本上你必须:

1 - 了解 Keras 上的层和激活结构:

您可以通过以下方式获取图层信息:

model_keras.summary()

如果您无法获取有关激活函数的信息,请尝试:

for layer in model_keras.layers:
    print(layer.output)

2 - 在 PyTorch 上构建与 Keras 具有相同层结构(和激活)的模型

3 - 最近创建的 PyTorch 模型(比方说model_pyt)与您在 Keras 上的模型具有不同的权重和偏差,因此您必须将这些权重和偏差从 Keras 模型复制到 PyTorch 模型

请注意,PYTORCH 权重与 KERAS 权重相关

然后,例如,复制权重就像:

model_pyt.layer1.weight.data = torch.tensor(model_keras.layers[0].get_weights()[0].T)

对于偏见:

model_pyt.layer1.bias.data = torch.tensor(model_keras.layers[0].get_weights()[1])

对所有图层重复此操作。

然后,您的 PyTorch 模型具有与 Keras 模型相同的架构和权重,并且可能以相同的方式运行。