代码之家  ›  专栏  ›  技术社区  ›  nbf77

将Tensorflow输入管道更改为“数据集”-错误

  •  0
  • nbf77  · 技术社区  · 7 年前

    到目前为止,我在tensorflow中使用的管道如下:

    queue_filenames = tf.train.string_input_producer(data)
    reader = tf.FixedLengthRecordReader(record_bytes=4*4)
    
    class Record(object):
        pass
    result = Record()
    result.ley, value = reader.read(queue_filenames)
    record = tf.decode_raw(value, tf.float32)
    image = tf.reshape(tf.strided_slice(record,[0],[1]),[1])
    label = tf.reshape(tf.strided_slice(record,[1],[4]),[3])
    
    x, y = tf.train.shuffle_batch([image, label],
                                  batch_size=batch_size,
                                  capacity=batch_size*3,
                                  min_after_dequeue=batch_size*2)
    

    但现在我想换一个“数据集”的东西。我写道:

    dataset = tf.data.FixedLengthRecordDataset(filenames=data,
                                               record_bytes=4*4)
    dataset.map(_generate_x_y)
    dataset.shuffle(buffer_size=batch_size*2)
    dataset.batch(batch_size=batch_size)
    dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    x, y = iterator.get_next()
    

    使用:

    def _generate_x_y(sample):
        features = {"x": tf.FixedLenFeature([1], tf.float32),
                    "y": tf.FixedLenFeature([3], tf.float32)}
        parsed_features = tf.parse_single_example(sample,features)
        return parsed_features["x"], parsed_features["y"]
    

    我的图表如下:

    y_ = network(x)
    

    以及:

    loss = tf.losses.softmax_cross_entropy(y,y_)
    train_step = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss=loss)
    

    我的会话是:

    with tf.Session(graph=graph_train) as sess:
        tf.global_variables_initializer().run()
        for i in range(100):
            _, = sess.run([train_step])
    

    它可以很好地处理旧管道,但对于新数据集,我得到以下错误:

    File "C:/***/main.py", line 49, in <module>
    x, y = iterator.get_next()
      File "C:\***\python\framework\ops.py", line 396, in __iter__
    "`Tensor` objects are not iterable when eager execution is not "
    TypeError: `Tensor` objects are not iterable when eager execution is not enabled. To iterate over this tensor use `tf.map_fn`.
    

    感谢您的帮助:-)

    1 回复  |  直到 7 年前
        1
  •  1
  •   iga    7 年前

    一个可能是问题原因的明显问题是您没有使用转换后的数据集。基本上,而不是

    dataset = tf.data.FixedLengthRecordDataset(filenames=data,
                                               record_bytes=4*4)
    dataset.map(_generate_x_y)
    dataset.shuffle(buffer_size=batch_size*2)
    

    您应该执行以下操作:

    dataset = tf.data.FixedLengthRecordDataset(filenames=data,
                                               record_bytes=4*4)
    dataset = dataset.map(_generate_x_y)
    dataset = dataset.shuffle(buffer_size=batch_size*2)
    

    每个数据集操作都会返回一个新的、已转换的数据集。原始对象不会被以下操作修改 map shuffle