在 TensorFlow Object Detection API 中提取特征图

数据挖掘 深度学习 张量流 美国有线电视新闻网 物体检测
2022-02-14 11:41:56

我想知道如何提取在 tensorflow 对象检测 API 上训练的移动网络的特征图。我想采用该特征图来提供另一个分类器。谢谢!

1个回答
import tensorflow as tf

# Sample frozen model
MODEL = "frozen_inference_graph.pb"

# An existing operation from the frozen model
OP_NAME = "WeightSharedConvolutionalBoxPredictor/BoxPredictionTower/conv2d_2/BatchNorm/feature_0/beta/read/_360__cf__363"

# Load the graph from the frozen model
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(MODEL, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')


# print all operations
for op in detection_graph.get_operations():
    print(op.name)

# print tensor ( without :0 you will get the operation itself )
print(detection_graph.get_tensor_by_name("{}:0".format(OP_NAME)))