I want to use TensorFlow's Dataset API to read a TFRecords file that contains variable-length lists. Here is my code.
import numpy as np
import tensorflow as tf

def _int64_feature(value):
    # value must be a numpy array.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def main1():
    # Write an array to TFRecord.
    # a is an array which contains lists of variant length.
    a = np.array([[0, 54, 91, 153, 177],
                  [0, 50, 89, 147, 196],
                  [0, 38, 79, 157],
                  [0, 49, 89, 147, 177],
                  [0, 32, 73, 145]])
    writer = tf.python_io.TFRecordWriter('file')
    for i in range(a.shape[0]):  # i = 0 ~ 4
        x_train = a[i]
        feature = {'i': _int64_feature(np.array([i])), 'data': _int64_feature(x_train)}
        # Create an example protocol buffer
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        # Serialize to string and write on the file
        writer.write(example.SerializeToString())
    writer.close()

    # Check TFRecord file.
    record_iterator = tf.python_io.tf_record_iterator(path='file')
    for string_record in record_iterator:
        example = tf.train.Example()
        example.ParseFromString(string_record)
        i = (example.features.feature['i'].int64_list.value)
        data = (example.features.feature['data'].int64_list.value)
        #data = np.fromstring(data_string, dtype=np.int64)
        print(i, data)

    # Use Dataset API to read the TFRecord file.
    def _parse_function(example_proto):
        keys_to_features = {'i': tf.FixedLenFeature([], tf.int64),
                            'data': tf.FixedLenFeature([], tf.int64)}
        parsed_features = tf.parse_single_example(example_proto, keys_to_features)
        return parsed_features['i'], parsed_features['data']

    ds = tf.data.TFRecordDataset('file')
    iterator = ds.map(_parse_function).make_one_shot_iterator()
    i, data = iterator.get_next()
    with tf.Session() as sess:
        print(i.eval())
        print(data.eval())
Output from checking the TFRecord file:
[0] [0, 54, 91, 153, 177]
[1] [0, 50, 89, 147, 196]
[2] [0, 38, 79, 157]
[3] [0, 49, 89, 147, 177]
[4] [0, 32, 73, 145]
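But when I try to read the TFRecord file with the Dataset API, I get the following error.
tensorflow.python.framework.errors_impl.InvalidArgumentError: Name: , Key: data, Index: 0. Number of int64 values != expected.
Values size: 5 but output shape: []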
Thank you very much.
Update:
I tried reading the TFRecord with the Dataset API using each of the following two versions of the code, but both failed.
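The message says that a record's `data` feature holds 5 values, while `tf.FixedLenFeature([], tf.int64)` declares a scalar, that is, exactly one value. Below is a rough plain-Python illustration of that count check. It is not TensorFlow's actual implementation, and `expected_count` and `check_fixed_len` are hypothetical names used only for this sketch.

```python
from functools import reduce
import operator

def expected_count(shape):
    """Number of values a fixed-length feature with this shape must hold."""
    # The product over an empty shape is 1, so shape [] means "one scalar".
    return reduce(operator.mul, shape, 1)

def check_fixed_len(values, shape):
    """Return True when the value count matches the declared shape."""
    return len(values) == expected_count(shape)

first_record = [0, 54, 91, 153, 177]
print(check_fixed_len([0], []))           # 'i' holds one value, shape [] expects one
print(check_fixed_len(first_record, []))  # 'data' holds five values, shape [] expects one
```

Since the lists in the file have either 4 or 5 elements, no single fixed shape matches every record.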
def _parse_function(example_proto):
    keys_to_features = {'i': tf.FixedLenFeature([], tf.int64),
                        'data': tf.VarLenFeature(tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, keys_to_features)
    return parsed_features['i'], parsed_features['data']

ds = tf.data.TFRecordDataset('file')
iterator = ds.map(_parse_function).make_one_shot_iterator()
i, data = iterator.get_next()
with tf.Session() as sess:
    print(sess.run([i, data]))
or
def _parse_function(example_proto):
    keys_to_features = {'i': tf.VarLenFeature(tf.int64),
                        'data': tf.VarLenFeature(tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, keys_to_features)
    return parsed_features['i'], parsed_features['data']

ds = tf.data.TFRecordDataset('file')
iterator = ds.map(_parse_function).make_one_shot_iterator()
i, data = iterator.get_next()
with tf.Session() as sess:
    print(sess.run([i, data]))
And the error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_util.py", line 468, in make_tensor_proto
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_util.py", line 468, in <listcomp>
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/compat.py", line 65, in as_bytes
    (bytes_or_text,))
TypeError: Expected binary or unicode string, got

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "2tfrecord.py", line 126, in <module>
    main1()
  File "2tfrecord.py", line 72, in main1
    iterator = ds.map(_parse_function).make_one_shot_iterator()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 712, in map
    return MapDataset(self, map_func)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1385, in __init__
    self._map_func.add_to_graph(ops.get_default_graph())
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 486, in add_to_graph
    self._create_definition_if_needed()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 321, in _create_definition_if_needed
    self._create_definition_if_needed_impl()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 338, in _create_definition_if_needed_impl
    outputs = self._func(*inputs)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1376, in tf_map_func
    flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1376, in <listcomp>
    flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 836, in convert_to_tensor
    as_ref=False)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 926, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/constant_op.py", line 229, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/constant_op.py", line 208, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_util.py", line 472, in make_tensor_proto
    "supported type." % (type(values), values))
TypeError: Failed to convert object of type to Tensor. Contents: SparseTensor(indices=Tensor("ParseSingleExample/Slice_Indices_i:0", shape=(?, 1), dtype=int64), values=Tensor("ParseSingleExample/ParseExample/ParseExample:3", shape=(?,), dtype=int64), dense_shape=Tensor("ParseSingleExample/Squeeze_Shape_i:0", shape=(1,), dtype=int64)). Consider casting elements to a supported type.
Python version: 3.5.2
TensorFlow version: 1.4.1
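The second traceback reports that a `SparseTensor` could not be converted to a `Tensor`. `tf.VarLenFeature` produces a sparse result described by indices, values and a dense shape, and the traceback shows `map` calling `ops.convert_to_tensor` on every element the parse function returns. As a plain-Python sketch of what turning such a triple into a dense list involves (an illustration only; `sparse_to_dense_1d` is a hypothetical helper, and the corresponding TensorFlow 1.x op is, as far as I know, `tf.sparse_tensor_to_dense`):

```python
def sparse_to_dense_1d(indices, values, dense_shape, default_value=0):
    """Build a dense 1-D list from a sparse (indices, values, dense_shape) triple."""
    dense = [default_value] * dense_shape[0]
    # Each entry of `indices` is a one-element coordinate for a rank-1 tensor.
    for (position,), value in zip(indices, values):
        dense[position] = value
    return dense

# The third record, [0, 38, 79, 157], written the way a rank-1 sparse result
# would describe it: one coordinate per stored value.
indices = [[0], [1], [2], [3]]
values = [0, 38, 79, 157]
dense_shape = [4]
print(sparse_to_dense_1d(indices, values, dense_shape))
```

If that assumption about the op holds, applying the dense conversion inside `_parse_function` before returning would hand `map` ordinary tensors instead of `SparseTensor` objects; I have not verified this on 1.4.1.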