代码之家  ›  专栏  ›  技术社区  ›  haxtar

不从目录中提取而还原Tensorflow模型

  •  7
  • haxtar  · 技术社区  · 6 年前

    Saver

    saver.save(sess, checkpoint_prefix, global_step=step)

    saver.restore(sess, checkpoint_file)

    .ckpt 将模型的文件复制到指定路径。因为我正在进行多个实验,所以保存这些模型的空间有限。

    我想知道是否有办法保存这些模型而不保存指定目录中的内容。

    在我看来 save_path 中的参数 tf.train.Saver.restore()

    如有任何见解,将不胜感激。

    谢谢

    1 回复  |  直到 6 年前
        1
  •  1
  •   McAngus    6 年前

    1000 tf.Session 打电话 sess

    x = tf.placeholder(...)
    loss, train_step = model(x)
    for i in range(num_step):
        input_x = get_train_data(i)
        sess.run(train_step, feed_dict={x: input_x})
        if i % 1000 == 0 and i != 0:
            eval_loss = 0
            for j in range(num_eval):
                input_x = get_eval_data(j)
                eval_loss += sess.run(loss, feed_dict={x: input_x})
            print(eval_loss/num_eval)
    

    如果你用的是 tf.data 对于您的输入,您只需创建一个 tf.cond 要选择要使用的输入:

    is_training = tf.placeholder(tf.bool)
    next_element = tf.cond(is_training,
                            lambda: get_next_train(),
                            lambda: get_next_eval())
    

    get_next_train get_next_eval

    这样,如果不想,就不必将任何内容保存到光盘。