我想用一个 tf.estimator.Estimator tf.data 应用程序编程接口。
tf.estimator.Estimator
tf.data
我有这样的想法:
def model_fn(features, labels, params, mode): # Defines model's ops. # Initializes with tf.train.Scaffold. # Returns an tf.estimator.EstimatorSpec. def input_fn(): dataset = tf.data.TextLineDataset("test.txt") # map, shuffle, padded_batch, etc. iterator = dataset.make_initializable_iterator() return iterator.get_next() estimator = tf.estimator.Estimator(model_fn) estimator.train(input_fn)
因为我不能使用 make_one_shot_iterator 对于我的用例,我的问题是 input_fn model_fn tf.train.Scaffold 初始化本地操作)。
make_one_shot_iterator
input_fn
model_fn
tf.train.Scaffold
input_fn = iterator.get_next
从TensorFlow 1.5开始,可以 input_fn 返回a tf.data.Dataset ,例如:
tf.data.Dataset
def input_fn(): dataset = tf.data.TextLineDataset("test.txt") # map, shuffle, padded_batch, etc. return dataset
看见 c294fcfd .
对于以前的版本,可以在 tf.GraphKeys.TABLE_INITIALIZERS
tf.GraphKeys.TABLE_INITIALIZERS
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)