在我看来,你用发电机使你的生活变得不必要的复杂。
我将这样实现您的输入管道:
def parse_file_tf(filename):
return tf.py_func(parse_file, [filename], [tf.float32, tf.float32])
# version with map
files = tf.data.Dataset.from_tensor_slices(files_to_process)
dataset = files.map(parse_file_tf, num_parallel_calls=N)
dataset = dataset.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))
dataset = dataset.batch(batch_size).shuffle(shuffle_size).prefetch(2)
it = dataset.make_one_shot_iterator()
为了测试它,我定义了一个假人
parse_file
作为:
i=0
def parse_file(f):
global i
i += 1
return np.asarray([i]*i, dtype=np.float32), np.asarray([i]*i, dtype=np.float32) # mimicks variable-length examples_x, examples_y
sess = tf.Session()
try:
while True:
x, y = it.get_next()
vx, vy = sess.run([x,y])
print(vx)
print(vy)
except tf.errors.OutOfRangeError:
pass
sess.close()
运行上面的代码打印:
[2. 3. 2. 1. 3. 3.]
[2. 3. 2. 1. 3. 3.]
基本上,我把并行化问题留给
map
,在这里我可以传递它应该运行的线程数。不需要生成器遍历范围和那些额外的复杂度。
我选择了地图
parallel_interleave
因为后者需要生成一个
Dataset
解析\u文件
并行交叉
如果您缓慢地生成值(例如,通过应用
tf.data.TFRecordDataset
文件名列表),但如果数据集适合内存,则
地图
.
tf.py_func
局限性,它们不影响你训练的网络,只影响输入管道。理想情况下,你将有一个不同的管道为你的培训和你的网络的最终使用。在后一种情况下,您只需注意这些限制,而对于培训(除非您对分布式培训和/或跨机器移动培训进行了非常具体的操作),您是相当安全的。
如果您的JSON文件非常大,并且它们的内容无法放入内存,那么您可以使用生成器,但与您开始使用的方法略有不同。
yield
一次只有一张唱片。那么,发电机必须是你的
功能。举个例子,假设你有
i = 3
def parse_file(filename):
global i
i += 1
ctr = 0
while ctr < i:
yield ctr, ctr
在这种情况下,管道将如下所示:
def wrap_generator(filename):
return tf.data.Dataset.from_generator(parse_file(filename), [tf.int32, tf.int32])
files = tf.data.Dataset.from_tensor_slices(files_to_process)
dataset = files.apply(tf.contrib.data.parallel_interleave(wrap_generator, cycle_length=N))
dataset = dataset.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))
dataset = dataset.shuffle(shuffle_size).batch(batch_size).prefetch(2)
it = dataset.make_one_shot_iterator()
并行交叉
数据集
从中提取值的实例。
将其送入与上述打印相同的样本循环:
[6. 5. 4. 4. 6. 5. 6. 6. 5. 4. 6. 4. 5. 5. 6.]
[6. 5. 4. 4. 6. 5. 6. 6. 5. 4. 6. 4. 5. 5. 6.]