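It sounds like you can put whatever operations you want to run multiple times inside a `tf.while_loop`. If the operations are independent of one another, the loop may run several iterations concurrently (up to `parallel_iterations`, which defaults to 10). In that case you may need to set `parallel_iterations` to `1` or, better, use control dependencies to sequence the optimizer calls. For example: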
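```python
# TensorFlow 1.x graph-mode example.
import tensorflow as tf

with tf.Graph().as_default():
    opt = tf.train.AdamOptimizer(0.1)
    var = tf.get_variable(name="var", shape=[], use_resource=True)

    def _cond(i, _):
        # Run the loop body 20 times.
        return tf.less(i, 20)

    def _body(i, sequencer):
        # Compute the loss only after the previous iteration's update has run.
        with tf.control_dependencies([sequencer]):
            loss = .5 * (var - 10.) ** 2
            print_op = tf.Print(loss, ["Evaluating loss", i, loss])
        # Apply the optimizer step only after the loss has been printed.
        with tf.control_dependencies([print_op]):
            train_op = opt.minimize(loss)
        # A dummy value that the next iteration depends on, chaining the steps.
        with tf.control_dependencies([train_op]):
            next_sequencer = tf.ones([])
        return i + 1, next_sequencer

    # Read the variable before the loop starts...
    initial_value = var.read_value()
    with tf.control_dependencies([initial_value]):
        _, sequencer = tf.while_loop(cond=_cond, body=_body, loop_vars=[0, 1.])
    # ...and again after all optimizer steps have finished.
    with tf.control_dependencies([sequencer]):
        final_value = var.read_value()
    init_op = tf.global_variables_initializer()

    # The session must be created inside the graph context to use this graph.
    with tf.Session() as session:
        session.run([init_op])
        print(session.run([initial_value, final_value]))
```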
This prints:
```
2017-12-21 11:40:35.920035: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][0][46.3987083]
2017-12-21 11:40:35.920317: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][1][45.4404]
2017-12-21 11:40:35.920534: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][2][44.4923515]
2017-12-21 11:40:35.920715: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][3][43.55476]
2017-12-21 11:40:35.920905: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][4][42.6277695]
2017-12-21 11:40:35.921084: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][5][41.711544]
2017-12-21 11:40:35.921273: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][6][40.8062363]
2017-12-21 11:40:35.921426: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][7][39.9120026]
2017-12-21 11:40:35.921578: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][8][39.028965]
2017-12-21 11:40:35.921732: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][9][38.1572723]
2017-12-21 11:40:35.921888: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][10][37.2970314]
2017-12-21 11:40:35.922053: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][11][36.4483566]
2017-12-21 11:40:35.922187: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][12][35.6113625]
2017-12-21 11:40:35.922327: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][13][34.7861366]
2017-12-21 11:40:35.922472: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][14][33.9727631]
2017-12-21 11:40:35.922613: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][15][33.1713257]
2017-12-21 11:40:35.922777: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][16][32.3818779]
2017-12-21 11:40:35.922942: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][17][31.6044941]
2017-12-21 11:40:35.923115: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][18][30.8392067]
2017-12-21 11:40:35.923253: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][19][30.0860634]
[0.36685812, 2.3390481]
```
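The printed losses can be cross-checked with a plain-Python replay of the loop. This is a minimal sketch, not TensorFlow code. It assumes the TF 1.x `AdamOptimizer` update rule with its default `beta1=0.9`, `beta2=0.999` and `epsilon=1e-8`. It also takes the starting value `0.36685812` from the output and uses a hypothetical helper name, `replay_adam`. Each replayed iteration reads the value left by the previous update, which is exactly what the control dependencies enforce. The replayed losses should therefore agree with the logged ones up to float32 rounding.

```python
import math


def replay_adam(x0, steps=20, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, target=10.0):
    """Hypothetical helper: replay `steps` sequential Adam updates on
    loss = 0.5 * (x - target) ** 2, assuming the TF 1.x AdamOptimizer rule.

    Returns the loss seen at the start of each iteration and the final x.
    """
    x, m, v = x0, 0.0, 0.0
    losses = []
    for t in range(1, steps + 1):
        # Each iteration sees the value produced by the previous update.
        losses.append(0.5 * (x - target) ** 2)
        g = x - target  # gradient of 0.5 * (x - target) ** 2
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        # Bias-corrected step size, as in the TF 1.x AdamOptimizer docs.
        lr_t = lr * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
        x -= lr_t * m / (math.sqrt(v) + eps)
    return losses, x


# Starting value taken from the printed [initial_value, final_value] pair.
losses, final = replay_adam(0.36685812)
for i, loss in enumerate(losses):
    print(i, round(loss, 4))
print("final value:", round(final, 4))
```

Because the gradient keeps the same sign, each bias-corrected Adam step moves `var` by roughly the learning rate (0.1). This is why the loss shrinks steadily and the final value ends up about 2 above the initial one.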