代码之家  ›  专栏  ›  技术社区  ›  Eduardo

使用带有tf.dataset api的python函数进行数据扩充

  •  0
  • Eduardo  · 技术社区  · 7 年前

    我正在寻找动态读取图像和应用数据增强我的图像分割问题。从我目前所看到的情况来看,最好的方法是 tf.Dataset API与 .map 功能。

    但是,从我所看到的例子来看,我认为我必须将所有的函数调整为tensorflow样式(使用 tf.cond 而不是 if 等)。问题是我需要应用一些非常复杂的函数。所以我考虑用 tf.py_func 这样地:

    import tensorflow as tf
    
    img_path_list = [...]   # List of paths to read
    mask_path_list = [...]  # List of paths to read
    
    dataset = tf.data.Dataset.from_tensor_slices((img_path_list, mask_path_list))
    
    def parse_function(img_path_list, mask_path_list):
        '''load image and mask from paths'''
        return img, mask
    
    def data_augmentation(img, mask):
        '''process data with complex logic'''
        return aug_img, aug_mask
    
    # py_func wrappers
    def parse_function_wrapper(img_path_list, mask_path_list):
        return tf.py_func(func=parse_function,
                          inp=(img_path_list, mask_path_list),
                          Tout=(tf.float32, tf.float32))
    
    def data_augmentation_wrapper(img, mask):
        return tf.py_func(func=data_augmentation,
                          inp=(img, mask),
                          Tout=(tf.float32, tf.float32))        
    
    # Maps py_funcs to dataset
    dataset = dataset.map(parse_function_wrapper,
                          num_parallel_calls=4)
    dataset = dataset.map(data_augmentation_wrapper,
                          num_parallel_calls=4)
    
    dataset = dataset.batch(32)
    iter = dataset.make_one_shot_iterator()
    imgs, labels = iter.get_next()
    

    但是,从 this answer 似乎用 py_func 因为并行性不起作用。还有别的选择吗?

    1 回复  |  直到 7 年前
        1
  •  1
  •   Alexandre Passos    7 年前

    py_u func受python gil的限制,因此在那里不会得到太多的并行性。最好的办法是用tensorflow proper编写数据扩充(或者预计算并序列化到磁盘)。

    如果您确实想用tensorflow编写它,可以尝试使用tf.contrib.autograph将简单的python ifs和for循环转换为tf.conds和tf.while_循环,这可能会大大简化您的代码。

    推荐文章