
How do I use the Dataset API to read a TFRecords file containing variable-length lists?

  •  8
  • Lion Lai  · Tech Community  · 7 years ago

    I want to use TensorFlow's Dataset API to read a TFRecords file that contains lists of variable length. Here is my code.

    import numpy as np
    import tensorflow as tf

    def _int64_feature(value):
        # value must be a sequence of ints (a list or a 1-D numpy array).
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

    def main1():
        # Write an array to a TFRecord file.
        # a holds lists of variable length, so it is built as an object array.
        a = np.array([[0, 54, 91, 153, 177],
                      [0, 50, 89, 147, 196],
                      [0, 38, 79, 157],
                      [0, 49, 89, 147, 177],
                      [0, 32, 73, 145]], dtype=object)
    
        writer = tf.python_io.TFRecordWriter('file')
    
        for i in range(a.shape[0]): # i = 0 ~ 4
            x_train = a[i]
            feature = {'i': _int64_feature(np.array([i])), 'data': _int64_feature(x_train)}
    
            # Create an example protocol buffer
            example = tf.train.Example(features=tf.train.Features(feature=feature))
    
            # Serialize to string and write on the file
            writer.write(example.SerializeToString())
    
        writer.close()
    
        # Check the TFRecord file.
        record_iterator = tf.python_io.tf_record_iterator(path='file')
        for string_record in record_iterator:
            example = tf.train.Example()
            example.ParseFromString(string_record)
    
            i = (example.features.feature['i'].int64_list.value)
            data = (example.features.feature['data'].int64_list.value)
            #data = np.fromstring(data_string, dtype=np.int64)
            print(i, data)
    
        # Use Dataset API to read the TFRecord file.
        def _parse_function(example_proto):
            keys_to_features = {'i'   :tf.FixedLenFeature([], tf.int64),
                                'data':tf.FixedLenFeature([], tf.int64)}
            parsed_features = tf.parse_single_example(example_proto, keys_to_features)
            return parsed_features['i'], parsed_features['data']
    
        ds = tf.data.TFRecordDataset('file')
        iterator = ds.map(_parse_function).make_one_shot_iterator()
        i, data = iterator.get_next()
        with tf.Session() as sess:
            print(i.eval())
            print(data.eval())
    

    Output of the TFRecord file check:

    [0] [0, 54, 91, 153, 177]
    [1] [0, 50, 89, 147, 196]
    [2] [0, 38, 79, 157]
    [3] [0, 49, 89, 147, 177]
    [4] [0, 32, 73, 145]
    

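    But when I try to read the TFRecord file with the Dataset API, I get the following error.

    tensorflow.python.framework.errors_impl.InvalidArgumentError: Name: , Key: data, Index: 0. Number of int64 values != expected. Values size: 5 but output shape: []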

    Thank you very much.
    Update: I tried the following two snippets to read the TFRecord file with the Dataset API, but both failed.

    def _parse_function(example_proto):
        keys_to_features = {'i'   :tf.FixedLenFeature([], tf.int64),
                            'data':tf.VarLenFeature(tf.int64)}
        parsed_features = tf.parse_single_example(example_proto, keys_to_features)
        return parsed_features['i'], parsed_features['data']
    
    ds = tf.data.TFRecordDataset('file')
    iterator = ds.map(_parse_function).make_one_shot_iterator()
    i, data = iterator.get_next()
    with tf.Session() as sess:
        print(sess.run([i, data]))
    

    def _parse_function(example_proto):
        keys_to_features = {'i'   :tf.VarLenFeature(tf.int64),
                            'data':tf.VarLenFeature(tf.int64)}
        parsed_features = tf.parse_single_example(example_proto, keys_to_features)
        return parsed_features['i'], parsed_features['data']
    
    ds = tf.data.TFRecordDataset('file')
    iterator = ds.map(_parse_function).make_one_shot_iterator()
    i, data = iterator.get_next()
    with tf.Session() as sess:
        print(sess.run([i, data]))
    

    And the error:

    Traceback (most recent call last):
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_util.py", line 468, in make_tensor_proto
        str_values = [compat.as_bytes(x) for x in proto_values]
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_util.py", line 468, in
        str_values = [compat.as_bytes(x) for x in proto_values]
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/compat.py", line 65, in as_bytes
        (bytes_or_text,))
    TypeError: Expected binary or unicode string, got

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "2tfrecord.py", line 126, in
        main1()
      File "2tfrecord.py", line 72, in main1
        iterator = ds.map(_parse_function).make_one_shot_iterator()
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 712, in map
        return MapDataset(self, map_func)
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1385, in __init__
        self._map_func.add_to_graph(ops.get_default_graph())
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 486, in add_to_graph
        self._create_definition_if_needed()
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 321, in _create_definition_if_needed
        self._create_definition_if_needed_impl()
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 338, in _create_definition_if_needed_impl
        outputs = self._func(*inputs)
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1376, in tf_map_func
        flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1376, in
        flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 836, in convert_to_tensor
        as_ref=False)
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 926, in internal_convert_to_tensor
        ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/constant_op.py", line 229, in _constant_tensor_conversion_function
        return constant(v, dtype=dtype, name=name)
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/constant_op.py", line 208, in constant
        value, dtype=dtype, shape=shape, verify_shape=verify_shape))
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_util.py", line 472, in make_tensor_proto
        "supported type." % (type(values), values))
    TypeError: Failed to convert object of type  to Tensor. Contents: SparseTensor(indices=Tensor("ParseSingleExample/Slice_Indices_i:0", shape=(?, 1), dtype=int64), values=Tensor("ParseSingleExample/ParseExample/ParseExample:3", shape=(?,), dtype=int64), dense_shape=Tensor("ParseSingleExample/Squeeze_Shape_i:0", shape=(1,), dtype=int64)). Consider casting elements to a supported type.

    Python version: 3.5.2
    TensorFlow version: 1.4.1

    2 answers  |  as of 7 years ago
        1
  •  12
  •   Lion Lai    7 years ago

    After hours of searching and experimenting, I believe I have finally found the answer. Here is my code.

    import numpy as np
    import tensorflow as tf

    def _int64_feature(value):
        # value must be a numpy array.
        return tf.train.Feature(int64_list=tf.train.Int64List(value=value.flatten()))
    
    # Write an array to a TFRecord file.
    # a holds lists of variable length, so it is built as an object array.
    a = np.array([[0, 54, 91, 153, 177],
                  [0, 50, 89, 147, 196],
                  [0, 38, 79, 157],
                  [0, 49, 89, 147, 177],
                  [0, 32, 73, 145]], dtype=object)
    
    writer = tf.python_io.TFRecordWriter('file')
    
    for i in range(a.shape[0]): # i = 0 ~ 4
        x_train = np.array(a[i])
        feature = {'i'   : _int64_feature(np.array([i])), 
                   'data': _int64_feature(x_train)}
    
        # Create an example protocol buffer
        example = tf.train.Example(features=tf.train.Features(feature=feature))
    
        # Serialize to string and write on the file
        writer.write(example.SerializeToString())
    
    writer.close()
    
    # Check the TFRecord file.
    record_iterator = tf.python_io.tf_record_iterator(path='file')
    for string_record in record_iterator:
        example = tf.train.Example()
        example.ParseFromString(string_record)
    
        i = (example.features.feature['i'].int64_list.value)
        data = (example.features.feature['data'].int64_list.value)
        print(i, data)
    
    # Use Dataset API to read the TFRecord file.
    filenames = ["file"]
    dataset = tf.data.TFRecordDataset(filenames)
    def _parse_function(example_proto):
        keys_to_features = {'i':tf.VarLenFeature(tf.int64),
                            'data':tf.VarLenFeature(tf.int64)}
        parsed_features = tf.parse_single_example(example_proto, keys_to_features)
        return tf.sparse_tensor_to_dense(parsed_features['i']), \
               tf.sparse_tensor_to_dense(parsed_features['data'])
    # Parse the record into tensors.
    dataset = dataset.map(_parse_function)
    # Shuffle the dataset (a buffer_size of 1 keeps the original order).
    dataset = dataset.shuffle(buffer_size=1)
    # Repeat the input indefinitely.
    dataset = dataset.repeat()  
    # Generate batches
    dataset = dataset.batch(1)
    # Create a one-shot iterator
    iterator = dataset.make_one_shot_iterator()
    i, data = iterator.get_next()
    with tf.Session() as sess:
        print(sess.run([i, data]))
        print(sess.run([i, data]))
        print(sess.run([i, data]))
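
    The pipeline above uses batch(1), so every batch holds a single row and the differing row lengths never have to be stacked. With a larger batch size the rows would need a common length first; in TensorFlow, Dataset.padded_batch is the usual tool for that. The sketch below only illustrates the padding idea in plain Python. The helper name pad_batch and the pad value 0 are assumptions made for this example, not part of the original code.

```python
def pad_batch(rows, pad_value=0):
    """Pad every row on the right so all rows share the longest length."""
    longest = max(len(row) for row in rows)
    return [list(row) + [pad_value] * (longest - len(row)) for row in rows]


# Two rows of different length taken from the example data.
batch = pad_batch([[0, 50, 89, 147, 196], [0, 38, 79, 157]])
print(batch)
```

    After padding, both rows have five entries and can be stacked into one rectangular batch.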
    

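    A few things are worth noting.
    1. This SO question was very helpful.
    2. tf.VarLenFeature returns a SparseTensor, so it has to be converted to a dense tensor with tf.sparse_tensor_to_dense.
    3. In my code, parse_single_example() cannot be replaced with parse_example(), which bothered me for a day. I don't know why parse_example() doesn't work. If anyone knows the reason, please let me know.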
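
    To make note 2 more concrete: a sparse tensor stores only the positions and values that are present, plus the overall shape, and the dense conversion fills every other position with a default. The plain-Python sketch below models that for the 1-D case a single parsed example produces. The function sparse_to_dense here is a hypothetical stand-in written for illustration; it is not TensorFlow's implementation.

```python
def sparse_to_dense(indices, values, dense_shape, default_value=0):
    """Build a 1-D dense list from sparse (position, value) pairs."""
    dense = [default_value] * dense_shape[0]
    for (position,), value in zip(indices, values):
        dense[position] = value
    return dense


# The third row of the example data, stored in sparse form.
indices = [(0,), (1,), (2,), (3,)]
values = [0, 38, 79, 157]
dense_shape = (4,)

dense = sparse_to_dense(indices, values, dense_shape)
print(dense)

# Positions that carry no value receive the default.
print(sparse_to_dense([(0,), (2,)], [7, 9], (4,)))
```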

        2
  •  2
  •   iga    7 years ago

    The error is simple. Your data is not a FixedLenFeature; it is a VarLenFeature. Replace this line:

     'data':tf.FixedLenFeature([], tf.int64)}
    

    with

     'data':tf.VarLenFeature(tf.int64)}
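
    Why the original spec fails can be modelled in a few lines. A fixed-length feature must hold exactly as many values as its shape implies, and a scalar shape [] implies one value, while each data row holds four or five. The sketch below is a simplified model of that check in plain Python; check_fixed_len and expected_value_count are hypothetical helpers, not TensorFlow functions.

```python
from functools import reduce
from operator import mul


def expected_value_count(shape):
    """Number of values a fixed-length feature of this shape must hold."""
    return reduce(mul, shape, 1)


def check_fixed_len(values, shape):
    """Raise ValueError when the value count does not match the shape."""
    expected = expected_value_count(shape)
    if len(values) != expected:
        raise ValueError(
            "Values size: %d but output shape: %s" % (len(values), list(shape)))
    return values


# 'i' holds one value, so the scalar shape [] fits.
check_fixed_len([3], [])

# 'data' holds five values, so the scalar shape [] is rejected.
try:
    check_fixed_len([0, 54, 91, 153, 177], [])
    outcome = "ok"
except ValueError as err:
    outcome = str(err)
print(outcome)
```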
    

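    Also, when you call print(i.eval()) and then print(data.eval()), you are advancing the iterator twice. The first print outputs 0, but the second outputs the data of the second row, [0, 50, 89, 147, 196]. Use print(sess.run([i, data])) instead to get i and data from the same row.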
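
    The same effect can be reproduced with an ordinary Python generator, which may make the behaviour easier to see. This is only an analogy under the assumption that each fetch pulls the next element; the generator rows below is a hypothetical stand-in for the parsed dataset.

```python
def rows():
    """Yield (i, data) pairs, standing in for the parsed dataset."""
    table = [
        [0, 54, 91, 153, 177],
        [0, 50, 89, 147, 196],
        [0, 38, 79, 157],
    ]
    for index, data in enumerate(table):
        yield index, data


# Two separate fetches: each one advances the iterator, so the index
# and the data end up coming from different rows.
it = rows()
i_first, _ = next(it)
_, data_second = next(it)
mismatched = (i_first, data_second)

# One fetch: both values come from the same row.
it = rows()
matched = next(it)

print(mismatched)
print(matched)
```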