代码之家  ›  专栏  ›  技术社区  ›  Mark

避免tf.data.Dataset数据集。从张量切片与估计器api

  •  2
  • Mark  · 技术社区  · 6 年前

    我在试着找出使用 dataset estimator 应用程序编程接口。我在网上看到的一切都是这样的:

    def train_input_fn():
       dataset = tf.data.Dataset.from_tensor_slices((features, labels))
       return dataset
    

    然后可以传递给估计器的序列函数:

     classifier.train(
        input_fn=train_input_fn,
        #...
     )
    

    但是 dataset guide

    上面的代码片段将在TensorFlow图中嵌入特性和标签数组tf.常数()操作。这对于小数据集很有效,但会浪费内存,因为数组的内容将被多次复制,并且对于数据集可能会达到2GB的限制tf.GraphDef文件协议缓冲区。

    然后描述了一个方法,该方法涉及定义占位符,然后用 feed_dict :

    features_placeholder = tf.placeholder(features.dtype, features.shape)
    labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
    
    dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
    
    sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                              labels_placeholder: labels})
    

    api,您没有手动运行会话。那么你如何使用 数据集 from_tensor_slices() ?

    2 回复  |  直到 6 年前
        1
  •  4
  •   Olivier Dehaene    6 年前

    要使用可初始化或可重新初始化的迭代器,必须创建继承自的类tf.train.SessionRunHook会话,在培训和评估步骤中可以多次访问会话。

    然后可以使用这个新类来初始化迭代器,这通常是在经典设置中执行的。您只需将这个新创建的钩子传递给training/evaluation函数或正确的train规范。

    class IteratorInitializerHook(tf.train.SessionRunHook):
        def __init__(self):
            super(IteratorInitializerHook, self).__init__()
            self.iterator_initializer_func = None # Will be set in the input_fn
    
        def after_create_session(self, session, coord):
            # Initialize the iterator with the data feed_dict
            self.iterator_initializer_func(session) 
    
    
    def get_inputs(X, y):
        iterator_initializer_hook = IteratorInitializerHook()
    
        def input_fn():
            X_pl = tf.placeholder(X.dtype, X.shape)
            y_pl = tf.placeholder(y.dtype, y.shape)
    
            dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
            dataset = ...
            ...
    
            iterator = dataset.make_initializable_iterator()
            next_example, next_label = iterator.get_next()
    
    
            iterator_initializer_hook.iterator_initializer_func = lambda sess: sess.run(iterator.initializer,
                                                                                        feed_dict={X_pl: X, y_pl: y})
    
            return next_example, next_label
    
        return input_fn, iterator_initializer_hook
    
    ...
    
    train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
    test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)
    
    ...
    
    estimator.train(input_fn=train_input_fn,
                    hooks=[train_iterator_initializer_hook]) # Don't forget to pass the hook !
    estimator.evaluate(input_fn=test_input_fn,
                       hooks=[test_iterator_initializer_hook])