我想知道如何提取在 tensorflow 对象检测 API 上训练的移动网络的特征图。我想采用该特征图来提供另一个分类器。谢谢!
在 TensorFlow Object Detection API 中提取特征图
数据挖掘
深度学习
张量流
美国有线电视新闻网
物体检测
2022-02-14 11:41:56
1个回答
import tensorflow as tf
# Sample frozen model
MODEL = "frozen_inference_graph.pb"
# An existing operation from the frozen model
OP_NAME = "WeightSharedConvolutionalBoxPredictor/BoxPredictionTower/conv2d_2/BatchNorm/feature_0/beta/read/_360__cf__363"
# Load the graph from the frozen model
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(MODEL, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
# print all operations
for op in detection_graph.get_operations():
print(op.name)
# print tensor ( without :0 you will get the operation itself )
print(detection_graph.get_tensor_by_name("{}:0".format(OP_NAME)))
其它你可能感兴趣的问题