代码之家  ›  专栏  ›  技术社区  ›  I. A Ziang Yan

如何将skip与tf.data.Dataset数据集tensorflow中的api或将批处理设置为0

  •  1
  • I. A Ziang Yan  · 技术社区  · 7 年前

    以下是开始代码:

    import tensorflow as tf
    import numpy as np
    import time
    
    index1 = tf.Variable(-1, dtype=tf.int32, trainable=False)
    index2 = tf.Variable(0, dtype=tf.int32, trainable=False)
    
    starting_point1 = tf.Variable(tf.constant(0, dtype=tf.int64), trainable=False)
    starting_point2 = tf.Variable(tf.constant(0, dtype=tf.int64), trainable=False)
    starting_point3 = tf.Variable(tf.constant(0, dtype=tf.int64), trainable=False)
    starting_point4 = tf.Variable(tf.constant(0, dtype=tf.int64), trainable=False)
    starting_point5 = tf.Variable(tf.constant(0, dtype=tf.int64), trainable=False)
    
    def starting_point_add_1_0(starting_point1, n):
        assignment1 = tf.assign_add(starting_point1, n)
        assignment2 = tf.assign_add(starting_point1, 0)
        return assignment1, assignment2
    
    mod_op1 = tf.mod(index1, 5)
    mod_op1 = tf.Print(input_=mod_op1, data=[mod_op1], message="mod_op1")
    
    mod_op2 = tf.mod(index2, 5)
    mod_op2 = tf.Print(input_=mod_op2, data=[mod_op2], message="mod_op2")
    
    condition = tf.logical_and(tf.equal(mod_op1, 0), tf.equal(mod_op2, 1))
    condition = tf.Print(input_=condition, data=[condition], message="condition")
    
    ass1_1, ass1_2 = tf.cond(condition, 
                             lambda: starting_point_add_1_0(starting_point1, 1),
                             lambda: starting_point_add_1_0(starting_point1, 0))
    
    # ass2 = tf.cond(tf.cast(index1 % 5 == 1 or index2 % 5 == 2, dtype=tf.bool), 
    #               lambda: tf.assign_add(starting_point2, 1),
    #               lambda: tf.assign_add(starting_point2, 0))
    
    # ass3 = tf.cond(index1 % 5 == 2 or index2 % 5 == 3, 
    #               lambda: tf.assign_add(starting_point3, 1),
    #               lambda: tf.assign_add(starting_point3, 0))
    
    # ass4 = tf.cond(index1 % 5 == 3 or index2 % 5 == 4, 
    #               lambda: tf.assign_add(starting_point4, 1),
    #               lambda: tf.assign_add(starting_point4, 0))
    
    # ass5 = tf.cond(index1 % 5 == 4 or index2 % 5 == 0, 
    #               lambda: tf.assign_add(starting_point5, 1),
    #               lambda: tf.assign_add(starting_point5, 0))
    
    data1 = tf.data.Dataset.range(1, 20).skip(starting_point1)
    data2 = tf.data.Dataset.range(21, 40).skip(starting_point2)
    data3 = tf.data.Dataset.range(41, 60).skip(starting_point3)
    data4 = tf.data.Dataset.range(61, 80).skip(starting_point4)
    data5 = tf.data.Dataset.range(81, 100).skip(starting_point5)
    
    iterator1 = data1.make_initializable_iterator()
    iterator2 = data2.make_initializable_iterator()
    iterator3 = data3.make_initializable_iterator()
    iterator4 = data4.make_initializable_iterator()
    iterator5 = data5.make_initializable_iterator()
    
    d1 = iterator1.get_next()
    d2 = iterator2.get_next()
    d3 = iterator3.get_next()
    d4 = iterator4.get_next()
    d5 = iterator5.get_next()
    
    data_ = tf.stack((d1, d2, d3, d4, d5), axis=0)
    
    ass6 = tf.assign_add(index1, 1)
    ass7 = tf.assign_add(index2, 1)
    
    with tf.control_dependencies([ass6, ass7]):
        data = tf.gather_nd(data_, indices=[[index1 % 5], [index2 % 5]])
    
    init_op = tf.global_variables_initializer()
    
    
    with tf.Session() as sess:
        sess.run(init_op)
        sess.run(iterator1.initializer)
        sess.run(iterator2.initializer)
        sess.run(iterator3.initializer)
        sess.run(iterator4.initializer)
        sess.run(iterator5.initializer)
    
        try:
            for i in range(20):
                t1, t2, t3, s1 = sess.run([data, index1, index2, starting_point1])
                print(t1, t2, t3, ".....", s1)
                sess.run([ass1_1, ass1_2])
    
        except tf.errors.OutOfRangeError:
            print("error")
    

    因此,这段代码的主要目的是能够遍历我使用 tf.data.Dataset.range 功能。我想从第一个数据集中取一个元素,从第二个数据集中取一个元素。那么,我想考虑一下 data2 data3 ,那么 数据3 data4 ,那么 数据4 data5 ,那么 数据5 data1 以此类推。

    我这样想,结果如下:

    [ 1 21] 0 1 ..... 0
    [22 42] 1 2 ..... 1
    [43 63] 2 3 ..... 1
    [64 84] 3 4 ..... 1
    [85  5] 4 5 ..... 1
    [ 6 26] 5 6 ..... 1
    [27 47] 6 7 ..... 2
    [48 68] 7 8 ..... 2
    [69 89] 8 9 ..... 2
    [90 10] 9 10 ..... 2
    [11 31] 10 11 ..... 2
    [32 52] 11 12 ..... 3
    [53 73] 12 13 ..... 3
    [74 94] 13 14 ..... 3
    [95 15] 14 15 ..... 3
    [16 36] 15 16 ..... 3
    [37 57] 16 17 ..... 4
    [58 78] 17 18 ..... 4
    [79 99] 18 19 ..... 4
    error
    

    现在我刚刚测试了第一个数据集的起点,看看它是否工作正常,但是我得到了上面的结果。我以为我会得到一个 2 而不是 5 ,但事实并非如此。好像跳过几乎没有效果。

    我希望有人能帮我解决这个问题。

    最后,我在这方面的工作,因为我有一个自定义的数据集,将不适合GPU内存,因此我希望实现类似的东西。

    我试过以下测试 skip() 函数,结果如下:

    starting_point1 = tf.Variable(tf.constant(2, dtype=tf.int64), trainable=False)
    
    data1 = tf.data.Dataset.range(1, 20).skip(starting_point1)
    iterator1 = data1.make_initializable_iterator()
    d1 = iterator1.get_next()
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(iterator1.initializer)
    
        try:
            for i in range(10):
                print(sess.run(d1))
        except tf.errors.OutOfRangeError:
            print("error")
    

    输出:

    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    

    因此,如何强制迭代器在每个训练步骤跳过?

    非常感谢您的帮助!!

    0 回复  |  直到 7 年前
    推荐文章