我正在编写一个数据管道,将成批的时间序列和相应的标签输入到需要3D输入形状的LSTM模型中。我目前有以下几点:
def split(window):
return window[:-label_length], window[-label_length]
dataset = tf.data.Dataset.from_tensor_slices(data.sin)
dataset = dataset.window(input_length + label_length, shift=label_shift, stride=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(input_length + label_length))
dataset = dataset.map(split, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.cache()
dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed, reshuffle_each_iteration=False)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
最终形成的形状
for x, y in dataset.take(1): x.shape
是(32,20),其中32是批量大小,20是序列长度,但我需要一个(32,20,1)的形状,其中额外的维度表示特征。
我的问题是如何重塑,理想情况下是在
split
被传递到函数中
dataset.map
在缓存数据之前运行?