如何运行保存的 TensorFlow 模型?(视频预测模型)
数据挖掘
深度学习
张量流
2022-02-13 12:05:46
2个回答
TensorFlow 的 saver 用于保存特定模型在某个给定时刻的权重。当你想使用一个训练好的模型时,你必须首先定义模型的架构(它必须与保存权重时所用的架构完全一致),然后才能用同一个 “saver” 类来恢复权重:
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "../my_saved_model.ckpt")
关于你最初的问题。我认为,如果您刚开始使用深度学习和 TensorFlow,那么这是错误的起点,您应该首先了解 TensorFlow 的一般工作原理,将其应用于图像分类等更简单的任务(从 MNIST 开始)。
据我了解,您需要使用“construct_model”函数并将初始图像序列(视频)和一些动作张量传递给它,它应该输出预测帧:
def construct_model(images,
actions=None,
states=None,
iter_num=-1.0,
k=-1,
use_state=True,
num_masks=10,
stp=False,
cdna=True,
dna=False,
context_frames=2):
"""Build convolutional lstm video predictor using STP, CDNA, or DNA.
Args:
images: tensor of ground truth image sequences
actions: tensor of action sequences
states: tensor of ground truth state sequences
iter_num: tensor of the current training iteration (for sched. sampling)
k: constant used for scheduled sampling. -1 to feed in own prediction.
use_state: True to include state and action in prediction
num_masks: the number of different pixel motion predictions (and
the number of masks for each of those predictions)
stp: True to use Spatial Transformer Predictor (STP)
cdna: True to use Convoluational Dynamic Neural Advection (CDNA)
dna: True to use Dynamic Neural Advection (DNA)
context_frames: number of ground truth frames to pass in before
feeding in own predictions
Returns:
gen_images: predicted future image frames
gen_states: predicted future states
视频中的预测太容易了。使用 Opencv for video 将视频读入图像并将它们存储在一个 numpy 数组中。稍后您将加载您的 tf 模型,然后进行一些预测。
"""
"""
Sections of this code were taken from:
https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb
"""
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from utils import label_map_util
from utils import visualization_utils as vis_util
import cv2
# Path to frozen detection graph. This is the actual model that is used
# for the object detection.
PATH_TO_CKPT = '../freezed_pb5_optimized/frozen_inference_graph.pb'
# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('../../../training', 'object-detection.pbtxt')
NUM_CLASSES = 1
sys.path.append("..")
def detect_in_video():
# VideoWriter is the responsible of creating a copy of the video
# used for the detections but with the detections overlays. Keep in
# mind the frame size has to be the same as original video.
out = cv2.VideoWriter('pikachu_detection_1v3.avi', cv2.VideoWriter_fourcc(
'M', 'J', 'P', 'G'), 10, (1280, 720))
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
# Definite input and output Tensors for detection_graph
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
# Each box represents a part of the image where a particular object
# was detected.
detection_boxes = detection_graph.get_tensor_by_name(
'detection_boxes:0')
# Each score represent how level of confidence for each of the objects.
# Score is shown on the result image, together with the class
# label.
detection_scores = detection_graph.get_tensor_by_name(
'detection_scores:0')
detection_classes = detection_graph.get_tensor_by_name(
'detection_classes:0')
num_detections = detection_graph.get_tensor_by_name(
'num_detections:0')
cap = cv2.VideoCapture('PikachuKetchup.mp4')
while(cap.isOpened()):
# Read the frame
ret, frame = cap.read()
# Recolor the frame. By default, OpenCV uses BGR color space.
# This short blog post explains this better:
# https://www.learnopencv.com/why-does-opencv-use-bgr-color-format/
color_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
image_np_expanded = np.expand_dims(color_frame, axis=0)
# Actual detection.
(boxes, scores, classes, num) = sess.run(
[detection_boxes, detection_scores,
detection_classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
# Visualization of the results of a detection.
# note: perform the detections using a higher threshold
vis_util.visualize_boxes_and_labels_on_image_array(
color_frame,
np.squeeze(boxes),
np.squeeze(classes).astype(np.int32),
np.squeeze(scores),
category_index,
use_normalized_coordinates=True,
line_thickness=8,
min_score_thresh=.20)
cv2.imshow('frame', color_frame)
output_rgb = cv2.cvtColor(color_frame, cv2.COLOR_RGB2BGR)
out.write(output_rgb)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
out.release()
cap.release()
cv2.destroyAllWindows()
def main():
detect_in_video()
if __name__ == '__main__':
main()
其他你可能感兴趣的问题