代码之家  ›  专栏  ›  技术社区  ›  Matt

如何重塑Tensorflow数据集中的数据?

  •  0
  • Matt  · 技术社区  · 5 年前

    我正在编写一个数据管道,将成批的时间序列和相应的标签输入到需要3D输入形状的LSTM模型中。我目前有以下几点:

    def split(window):
        return window[:-label_length], window[-label_length]
    
    dataset = tf.data.Dataset.from_tensor_slices(data.sin)
    dataset = dataset.window(input_length + label_length, shift=label_shift, stride=1, drop_remainder=True)
    dataset = dataset.flat_map(lambda window: window.batch(input_length + label_length))
    dataset = dataset.map(split, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.cache()
    dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed, reshuffle_each_iteration=False)
    dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    

    最终形成的形状 for x, y in dataset.take(1): x.shape 是(32,20),其中32是批量大小,20是序列长度,但我需要一个(32,20,1)的形状,其中额外的维度表示特征。

    我的问题是如何重塑,理想情况下是在 split 被传递到函数中 dataset.map 在缓存数据之前运行?

    0 回复  |  直到 5 年前
        1
  •  1
  •   thushv89    5 年前

    这很简单。在拆分函数中执行此操作

    def split(window):
        return window[:-label_length, tf.newaxis], window[-label_length, tf.newaxis, tf.newaxis]