代码之家  ›  专栏  ›  技术社区  ›  P-Gn

如何在紧急执行时暂停和恢复渐变录制?

  •  0
  • P-Gn  · 技术社区  · 7 年前

    我有个任务是这样的:

    # compute estimates from input
    net_estimate = my_model(inputs)
    # use this estimate to compute a target
    target_estimate = lots_of_computations(net_estimate)
    # compute loss
    loss = compute_loss(net_estimate, target_estimate)
    

    (对于某些上下文,这是一个强化学习任务,其结果状态和奖励取决于网络采取的操作。)

    问题是我不想(实际上不能)计算 lots_of_computations . 理想情况下,我希望暂停并恢复渐变录制

    with tf.GradientTape() as tape:
      net_estimate = my_model(inputs)
    # target_estimate should be considered a constant
    target_estimate = lots_of_computations(net_estimate)
    with tape.resume():
      loss = compute_loss(net_estimate, target_estimate)
    tape.gradient(loss, my_model.params)
    

    但是 GradientTape 似乎没有提供类似的东西。有没有办法在急切的模式下实现这一点?我现在的解决办法是 net_estimate 两次,但这显然是次优的。

    1 回复  |  直到 7 年前
        1
  •  1
  •   ash    7 年前

    tf.GradientTape.stop_recording 可能就是你要找的。

    它是最近(在tensorflow 1.8之后)引入的,所以现在您需要使用tensorflow1.9.0的候选版本。

    希望能有所帮助。