我有个任务是这样的:
# compute estimates from input
net_estimate = my_model(inputs)
# use this estimate to compute a target
target_estimate = lots_of_computations(net_estimate)
# compute loss
loss = compute_loss(net_estimate, target_estimate)
(对于某些上下文,这是一个强化学习任务,其结果状态和奖励取决于网络采取的操作。)
问题是我不想(实际上不能)计算
lots_of_computations
. 理想情况下,我希望暂停并恢复渐变录制
with tf.GradientTape() as tape:
net_estimate = my_model(inputs)
# target_estimate should be considered a constant
target_estimate = lots_of_computations(net_estimate)
with tape.resume():
loss = compute_loss(net_estimate, target_estimate)
tape.gradient(loss, my_model.params)
但是
GradientTape
似乎没有提供类似的东西。有没有办法在急切的模式下实现这一点?我现在的解决办法是
net_estimate
两次,但这显然是次优的。