我有一个关于 TensorFlow 的保护程序功能的一般性问题。
saver 类允许我们通过以下方式保存会话:
saver.save(sess, "checkpoints.ckpt")
并允许我们恢复会话:
saver.restore(sess, tf.train.latest_checkpoint("checkpoints.ckpt"))
在 TensorFlow 文档中,有一个示例代码(添加了 epoch 循环和恢复):
# Create a saver.
saver = tf.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session()
saver.restore(sess, tf.train.latest_checkpoint("checkpoints.ckpt"))
for epoch in xrange(25):
for step in xrange(1000000):
sess.run(..training_op..)
if step % 1000 == 0:
# Append the step number to the checkpoint name:
saver.save(sess, 'my-model', global_step=step)
问题是,如果我们在 处停止训练循环epoch=15并再次执行,那么如果我们epoch=0再次开始,但模型被训练到epoch=15.
有没有办法从 恢复epoch=15?