我最近从 PyTorch 切换到 TF(1 和 2),我正在尝试使用它获得良好的工作流程。
我想做的简单的事情如下:
- 从TF1 zoo或TF2 zoo加载完整的预训练对象检测模型
- 用于
model.summary()检查加载模型的网络架构。 - 使用预训练的加载模型进行推理。
- 修改(例如重塑、删除、添加)加载模型的层和权重。
- 重新训练修改后的加载模型。
我知道 TF 具有图形和权重的概念,而 PyTorch 仅具有包含所有内容的模型。尽管如此,我还是找不到一种简单且最佳的方法来加载预训练模型,并且互联网上充斥着针对不同 tf 版本的不同答案。
我真的很困惑,因为要实现上述几点,当我从TF1 zoo(或 TF2 zoo)下载预训练模型时,我有很多不同的文件可用。
以这个为例,TF1 动物园列表中的第一个。我有一个saved_model文件夹saved_model.pb和variables(空)文件夹,frozen_inference_graph.pb文件model.ckpt,pipeline.config在某些情况下还有一个event文件。所有这些不同的文件真的需要对图形结构和权重进行编码吗?我是否遗漏了什么或者这只是比必要的更复杂?此外,如果您从TF2 zoo下载模型,文件/文件夹结构会有所不同(见下图)
我试过的
import tensorflow as tf #(v2.4)
def load_pretrained_model(self, saved_model_sub_folder,
mode):
# 1. this only load an AutoTrackable object that can be use for inference but no graph
if mode == '.pb':
model_dir = str(TRAINED_MODEL_DIR) + saved_model_sub_folder
model_dir = pathlib.Path(model_dir) / "saved_model"
model = tf.saved_model.load(str(model_dir), None, '.')
detection_model = model.signatures['serving_default']
# 2. this returns None
elif mode == '.graph':
def load_graph(frozen_graph_filename):
with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
return graph_def
detection_model = tf.compat.v1.import_graph_def(load_graph(frozen_graph_filename))
else:
detection_model = None
return detection_model
Tl;博士
有人可以回答上面关于如何在 python3 中加载完整的(图形、权重、一切......)可定制的 tensorflow1 或 tensorflow2 模型的一些要点(1 到 5)吗?
