TensorFlow - 在一个时代的中间恢复训练?

数据挖掘 张量流
2022-03-07 12:34:40

我有一个关于 TensorFlow 的保护程序功能的一般性问题。

saver 类允许我们通过以下方式保存会话:

saver.save(sess, "checkpoints.ckpt")

并允许我们恢复会话:

saver.restore(sess, tf.train.latest_checkpoint("checkpoints.ckpt"))

在 TensorFlow 文档中,有一个示例代码(添加了 epoch 循环和恢复):

# Create a saver.
saver = tf.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session()
saver.restore(sess, tf.train.latest_checkpoint("checkpoints.ckpt"))
for epoch in xrange(25):
    for step in xrange(1000000):
        sess.run(..training_op..)
        if step % 1000 == 0:
            # Append the step number to the checkpoint name:
            saver.save(sess, 'my-model', global_step=step)

问题是,如果我们在 处停止训练循环epoch=15并再次执行,那么如果我们epoch=0再次开始,但模型被训练到epoch=15.

有没有办法从 恢复epoch=15

1个回答

网络不存储关于训练数据的训练进度——这不是它的状态的一部分,因为在任何时候你都可以决定改变什么数据集来提供它。你也许可以修改它,让它知道训练数据和进度,存储在某个张量的某个地方,但这很不寻常。因此,为了做到这一点,您需要保存和使用TensorFlow 框架之外的其他数据。

可能最简单的做法是将纪元编号添加到文件名中。您已经在 epoch 中添加了当前步骤,因此只需添加 epoch 相乘:

saver.save(sess, 'my-model', global_step=epoch*1000000+step)

xrange加载文件时,您可以解析文件名以发现您所处的时代和步骤,并将它们用作函数的起点。为了更容易从任何给定的检查点重新启动,您可以使用argparse允许您的脚本采用您要使用的检查点文件的名称。

简而言之,它可能看起来像这样:

# Near top of script
import argparse
import re

# Before main logic
parser = argparse.ArgumentParser()
parser.add_argument('checkpoint')
args = parser.parse_args()

start_epoch = 0
start_step = 0
if args.checkpoint:
    saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint))
    found_num = re.search(r'\d+', args.checkpoint)
    if found_num:
        checkpoint_id = int(found_num.group(0))
        start_epoch = checkpoint_id // 1000000
        start_step = checkpoint_id % 1000000

# Change to xrange:
for epoch in xrange(start_epoch, 25):
    for step in xrange(start_step, 1000000):
        sess.run(..training_op..) # etc

    # At end of epoch loop, you need to re-set steps:
    start_step = 0

您可能希望减少您正在创建的检查点的数量 - 就目前而言,您的代码将生成 25,000 个检查点文件。

另一种选择是使用单个检查点文件,并保存和恢复一个简单的 Python pickle,其中dict包含您创建检查点时的状态,名称相似。