我在 Keras 中训练了一个特征提取器,并将权重保存为 h5 文件。现在我想将相同的权重加载到在 PyTorch 中创建和初始化的相同模型中,以进行性能比较。有什么方法可以将 h5 文件转换为 pth 文件,以便将其加载到 PyTorch 模型中?
如何将 Keras h5 转换为 PyTorch pth 格式?
数据挖掘
深度学习
喀拉斯
火炬
2022-02-11 03:16:16
1个回答
您必须小心 TF 和 PyTorch 的版本(因为某些命令可能不同)。
基本上你必须:
1 - 了解 Keras 上的层和激活结构:
您可以通过以下方式获取图层信息:
model_keras.summary()
如果您无法获取有关激活函数的信息,请尝试:
for layer in model_keras.layers:
print(layer.output)
2 - 在 PyTorch 上构建与 Keras 具有相同层结构(和激活)的模型
3 - 最近创建的 PyTorch 模型(比方说model_pyt)与您在 Keras 上的模型具有不同的权重和偏差,因此您必须将这些权重和偏差从 Keras 模型复制到 PyTorch 模型:
请注意,PYTORCH 权重与 KERAS 权重相关
然后,例如,复制权重就像:
model_pyt.layer1.weight.data = torch.tensor(model_keras.layers[0].get_weights()[0].T)
对于偏见:
model_pyt.layer1.bias.data = torch.tensor(model_keras.layers[0].get_weights()[1])
对所有图层重复此操作。
然后,您的 PyTorch 模型具有与 Keras 模型相同的架构和权重,并且可能以相同的方式运行。
其它你可能感兴趣的问题