
Naming TensorFlow/Keras checkpoints

  • ntlarry  · Tech Community  · 5 years ago

    I am following TensorFlow's "Text generation with an RNN" tutorial (link):

    import os
    import tensorflow as tf

    # Directory where the checkpoints will be saved
    checkpoint_dir = './training_checkpoints'
    # Name of the checkpoint files (modified line; this is what raises the KeyError below)
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch+10}")

    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_prefix,
        save_weights_only=True)

    EPOCHS = 10
    

    The only difference from the original code on the TensorFlow website is that I changed the original line

    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

    to

    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch+10}")
    

    However, it does not work. The error is:

    KeyError: 'epoch+10'
    
    During handling of the above exception, another exception occurred:
    
    Traceback (most recent call last):
      File "project/RNN_text_generator_finetune.py", line 102, in <module>
        history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])
      File "/opt/miniconda3/envs/newest11142020/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
        return method(self, *args, **kwargs)
      File "/opt/miniconda3/envs/newest11142020/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1137, in fit
        callbacks.on_epoch_end(epoch, epoch_logs)
      File "/opt/miniconda3/envs/newest11142020/lib/python3.8/site-packages/tensorflow/python/keras/callbacks.py", line 412, in on_epoch_end
        callback.on_epoch_end(epoch, logs)
      File "/opt/miniconda3/envs/newest11142020/lib/python3.8/site-packages/tensorflow/python/keras/callbacks.py", line 1249, in on_epoch_end
        self._save_model(epoch=epoch, logs=logs)
      File "/opt/miniconda3/envs/newest11142020/lib/python3.8/site-packages/tensorflow/python/keras/callbacks.py", line 1282, in _save_model
        filepath = self._get_file_path(epoch, logs)
      File "/opt/miniconda3/envs/newest11142020/lib/python3.8/site-packages/tensorflow/python/keras/callbacks.py", line 1332, in _get_file_path
        raise KeyError('Failed to format this callback filepath: "{}". '
    KeyError: 'Failed to format this callback filepath: "./training_checkpoints/ckpt_{epoch+10}". Reason: \'epoch+10\''
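
In the TensorFlow 2.x version shown in the traceback, ModelCheckpoint appears to fill in its filepath template with Python str.format-style substitution, supplying epoch (and the entries of logs) as named fields; the failure surfaces in _get_file_path. Because str.format treats everything between the braces as a field name rather than an expression, "{epoch+10}" asks for a key literally named "epoch+10". The minimal sketch below reproduces the same KeyError in plain Python, without TensorFlow; try_format is a hypothetical helper, not part of Keras.

```python
def try_format(template, **values):
    """Apply str.format to template; return the KeyError text instead of raising."""
    try:
        return template.format(**values)
    except KeyError as err:
        return f"KeyError: {err}"


# "epoch+10" is parsed as one field *name*, not as the expression epoch + 10,
# so the lookup fails even though "epoch" itself is supplied.
print(try_format("./training_checkpoints/ckpt_{epoch+10}", epoch=1))
# prints: KeyError: 'epoch+10'

# A plain field name works, and a format spec (here: zero-pad to 3 digits)
# is the kind of transformation str.format does support.
print(try_format("./training_checkpoints/ckpt_{epoch}", epoch=1))
# prints: ./training_checkpoints/ckpt_1
print(try_format("./training_checkpoints/ckpt_{epoch:03d}", epoch=1))
# prints: ./training_checkpoints/ckpt_001
```

So a placeholder can carry a format spec such as {epoch:03d}, but it cannot do arithmetic; any offset has to come from the epoch value Keras passes in.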
    

    Is there a way to rename the checkpoints in the code?

  •   Innat  · 5 years ago

    When resuming training, you can set it like this:

    model.fit(...,
              initial_epoch=epoch,
              ...)
    

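    Here, initial_epoch is an integer: the epoch at which to start training (useful for resuming a previous training run). Say you trained your model for 10 epochs and then stopped. When you resume training, set initial_epoch to 10. Src, insightful discussion.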
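
The sketch below is a pure-Python simulation, not TensorFlow code, of the file names a ModelCheckpoint with the tutorial's original "ckpt_{epoch}" template would write. It assumes, as in TF 2.x Keras, that fit visits epoch indices range(initial_epoch, epochs) and that {epoch} is filled with the 1-based number (index + 1); checkpoint_names is a hypothetical helper. It illustrates that resuming with initial_epoch=10 already continues the numbering at ckpt_11, and that epochs means the final epoch, not the number of additional epochs.

```python
# Simulation only: assumes the fit loop runs over range(initial_epoch, epochs)
# and that {epoch} is substituted with index + 1, as in TF 2.x Keras.

def checkpoint_names(template, initial_epoch, epochs):
    """List the file names a ModelCheckpoint-style callback would produce."""
    # `epochs` is the index of the final epoch, not a count of extra epochs.
    return [template.format(epoch=index + 1) for index in range(initial_epoch, epochs)]


# First run: 10 epochs from scratch -> ckpt_1 ... ckpt_10
first_run = checkpoint_names("ckpt_{epoch}", initial_epoch=0, epochs=10)

# Resumed run: initial_epoch=10 with epochs=20 -> ckpt_11 ... ckpt_20,
# so the original "ckpt_{epoch}" template needs no "+10".
resumed_run = checkpoint_names("ckpt_{epoch}", initial_epoch=10, epochs=20)

# Pitfall: initial_epoch=10 with epochs=10 trains (and saves) nothing.
no_op_run = checkpoint_names("ckpt_{epoch}", initial_epoch=10, epochs=10)

print(first_run[0], first_run[-1])      # prints: ckpt_1 ckpt_10
print(resumed_run[0], resumed_run[-1])  # prints: ckpt_11 ckpt_20
print(no_op_run)                        # prints: []
```

Under those assumptions, the question's script would keep checkpoint_prefix at "ckpt_{epoch}" and call something like model.fit(dataset, initial_epoch=10, epochs=EPOCHS + 10, callbacks=[checkpoint_callback]), after restoring the first run's weights (for example with model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))).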
