代码之家  ›  专栏  ›  技术社区  ›  Zhao Chen

如何使用tensorflow的tf.data api加载pickle文件

  •  0
  • Zhao Chen  · 技术社区  · 7 年前

    我的数据存储在磁盘上的多个pickle文件中。我想使用tensorflow的tf.data.dataset将我的数据加载到训练管道中。我的密码是:

    def _parse_file(path):
        image, label = *load pickle file*
        return image, label
    paths = glob.glob('*.pkl')
    print(len(paths))
    dataset = tf.data.Dataset.from_tensor_slices(paths)
    dataset = dataset.map(_parse_file)
    iterator = dataset.make_one_shot_iterator()
    

    问题是我不知道如何实现 _parse_file 功能。这个函数的参数, path ,是张量类型。我试过

    def _parse_file(path):
        with tf.Session() as s:
            p = s.run(path)
            image, label = pickle.load(open(p, 'rb'))
        return image, label
    

    并收到错误消息:

    InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'arg0' with dtype string
         [[Node: arg0 = Placeholder[dtype=DT_STRING, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
    

    在网上搜索了一下,我还是不知道该怎么做。我会感谢任何人给我暗示。

    1 回复  |  直到 7 年前
        1
  •  3
  •   Zhao Chen    7 年前

    我自己解决了这个问题。我应该用 tf.py_func 如此 doc .

    推荐文章