How do I run a saved TensorFlow model? (video prediction model)

data-mining deep-learning tensorflow
2022-02-13 12:05:46

I am reading this paper. The code is available on GitHub. In the README.md file, they explain how to train the model:

python prediction_train.py

with optional arguments.
Can anyone explain how I can use this model to predict a video sequence? I am new to deep learning and TensorFlow, so I can't fully understand the code. My current task is to run the code and look at the output (i.e., the video predicted by this model).

What I can understand is that it uses the TensorFlow saver to save checkpoints. I guess these checkpoints are intermediate trained models saved after some number of iterations (in this case, 2000). How can I use these models to predict the next frames of a video?

Any help is greatly appreciated :)

2 Answers

The TensorFlow saver is used to save the weights of a specific model at some given point. When you want to use a trained model, you must first define the model's architecture (it should match the architecture that was used when the weights were saved), and then you can use the same saver class to restore the weights:

with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "../my_saved_model.ckpt")
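For example, a complete save-then-restore round trip might look like the sketch below. This is a minimal illustration, not the repository's code: it is written with tf.compat.v1 so it also runs on newer TensorFlow installs, the single variable stands in for the real network architecture, and the checkpoint path is a placeholder:

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# A tiny stand-in "model": one variable. In practice this is where
# you would rebuild the same graph used at training time
# (e.g. by calling construct_model from this repo).
w = tf.get_variable("w", initializer=3.0)

# The saver operates on the variables defined above.
saver = tf.train.Saver()

# Save a checkpoint (in the repo, prediction_train.py does this).
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, "/tmp/my_saved_model.ckpt")

# Later, after rebuilding the same graph, restore instead of
# initializing, then run inference with sess.run(...).
with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_saved_model.ckpt")
    print(sess.run(w))
```

The key point is that saver.restore only fills in variable values; the graph that defines those variables has to exist first.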

Regarding your original question: if you are just starting out with deep learning and TensorFlow, I think this is the wrong place to begin. You should first learn how TensorFlow works in general by applying it to simpler tasks such as image classification (start with MNIST).

As far as I understand it, you need to use the construct_model function and pass it an initial image sequence (the video) and some action tensors; it should output the predicted frames:

def construct_model(images,
                    actions=None,
                    states=None,
                    iter_num=-1.0,
                    k=-1,
                    use_state=True,
                    num_masks=10,
                    stp=False,
                    cdna=True,
                    dna=False,
                    context_frames=2):
  """Build convolutional lstm video predictor using STP, CDNA, or DNA.
  Args:
    images: tensor of ground truth image sequences
    actions: tensor of action sequences
    states: tensor of ground truth state sequences
    iter_num: tensor of the current training iteration (for sched. sampling)
    k: constant used for scheduled sampling. -1 to feed in own prediction.
    use_state: True to include state and action in prediction
    num_masks: the number of different pixel motion predictions (and
      the number of masks for each of those predictions)
    stp: True to use Spatial Transformer Predictor (STP)
    cdna: True to use Convolutional Dynamic Neural Advection (CDNA)
    dna: True to use Dynamic Neural Advection (DNA)
    context_frames: number of ground truth frames to pass in before
      feeding in own predictions
  Returns:
    gen_images: predicted future image frames
    gen_states: predicted future states
  """
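Based on that signature, a call might look roughly like this. The input tensors and the flag values chosen here are illustrative assumptions, not taken from the repository's actual input pipeline:

```python
# Hypothetical sketch: images, actions and states would come from
# the repo's data reader as per-timestep tensors.
gen_images, gen_states = construct_model(
    images,            # ground-truth frame sequence
    actions=actions,   # action sequence
    states=states,     # ground-truth state sequence
    k=-1,              # -1: feed the model its own predictions
    use_state=True,
    num_masks=10,
    cdna=True,         # use the CDNA variant
    context_frames=2)  # 2 real frames before self-feeding
# gen_images holds the predicted future frames.
```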

Prediction from video is straightforward. Use OpenCV to read the video into images and store them in a numpy array. Then load your tf model and run predictions.

"""
Sections of this code were taken from:
https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb
"""
import numpy as np

import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image

from utils import label_map_util

from utils import visualization_utils as vis_util

import cv2

# Path to frozen detection graph. This is the actual model that is used
# for the object detection.
PATH_TO_CKPT = '../freezed_pb5_optimized/frozen_inference_graph.pb'

# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('../../../training', 'object-detection.pbtxt')

NUM_CLASSES = 1

sys.path.append("..")


def detect_in_video():

    # VideoWriter creates a copy of the input video with the
    # detection overlays drawn on it. Keep in mind the frame size
    # has to be the same as the original video's.
    out = cv2.VideoWriter('pikachu_detection_1v3.avi', cv2.VideoWriter_fourcc(
        'M', 'J', 'P', 'G'), 10, (1280, 720))

    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

    label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)

    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            # Define input and output tensors for detection_graph.
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            # Each box represents a part of the image where a particular object
            # was detected.
            detection_boxes = detection_graph.get_tensor_by_name(
                'detection_boxes:0')
            # Each score represents the level of confidence for each of
            # the objects. The score is shown on the result image,
            # together with the class label.
            detection_scores = detection_graph.get_tensor_by_name(
                'detection_scores:0')
            detection_classes = detection_graph.get_tensor_by_name(
                'detection_classes:0')
            num_detections = detection_graph.get_tensor_by_name(
                'num_detections:0')
            cap = cv2.VideoCapture('PikachuKetchup.mp4')

            while cap.isOpened():
                # Read the next frame.
                ret, frame = cap.read()
                if not ret:
                    # End of video (or read error): stop the loop.
                    break

                # Recolor the frame. By default, OpenCV uses BGR color space.
                # This short blog post explains this better:
                # https://www.learnopencv.com/why-does-opencv-use-bgr-color-format/
                color_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                image_np_expanded = np.expand_dims(color_frame, axis=0)

                # Actual detection.
                (boxes, scores, classes, num) = sess.run(
                    [detection_boxes, detection_scores,
                        detection_classes, num_detections],
                    feed_dict={image_tensor: image_np_expanded})

                # Visualization of the results of a detection.
                # note: perform the detections using a higher threshold
                vis_util.visualize_boxes_and_labels_on_image_array(
                    color_frame,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8,
                    min_score_thresh=.20)

                cv2.imshow('frame', color_frame)
                output_rgb = cv2.cvtColor(color_frame, cv2.COLOR_RGB2BGR)
                out.write(output_rgb)

                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

            out.release()
            cap.release()
            cv2.destroyAllWindows()


def main():
    detect_in_video()


if __name__ == '__main__':
    main()