代码之家  ›  专栏  ›  技术社区  ›  ZHANG Juenjie

tf.收集\n用了很多次真的很慢

  •  1
  • ZHANG Juenjie  · 技术社区  · 7 年前

    我想要张量流中的损失函数,它是许多元素的复杂组合。例如,此代码:

    import tensorflow as tf
    import numpy as np
    import time
    
    input_layer = tf.placeholder(tf.float64, shape=[64,4])
    output_layer = input_layer + 0.5*tf.tanh(tf.Variable(tf.random_uniform(shape=[64,4],\
                                                           minval=-1,maxval=1,dtype=tf.float64)))
    
    # random_combination is 2-d numpy array of the form:
    # [[32, 34, 23, 56],[23,54,33,21],...]
    random_combination = np.random.randint(64, size=(210000000, 4))
    
    # a collector to collect the values 
    collector=[]
    
    print('start looping')   
    print(time.asctime(time.localtime(time.time())))
    
    # loop through random_combination and pick the elements of output_layer
    for i in range(len(random_combination)):
        [i,j,k,l] = [random_combination[i][0],random_combination[i][1],\
                     random_combination[i][2],random_combination[i][3]]
    
        # pick the needed element from output_layer
        f1 = tf.gather_nd(output_layer,[i,0])
        f2 = tf.gather_nd(output_layer,[i,2])
        f3 = tf.gather_nd(output_layer,[i,3])
        f4 = tf.gather_nd(output_layer,[i,4])
    
        tf1 = f1+1
        tf2 = f2+1
        tf3 = f3+1
        tf4 = f4+1
        collector.append(0.3*tf.abs(f1*f2*tf3*tf4-tf1*tf2*f3*f4))
    
    print('end looping')   
    print(time.asctime(time.localtime(time.time())))
    
    # loss function
    loss = tf.add_n(collector)
    

    在我的电脑上大约需要50分钟。 我的问题是,这是正确的方式做编码在张量流? 或者有更省时的方法来索引元素?

    0 回复  |  直到 7 年前
    推荐文章