代码之家  ›  专栏  ›  技术社区  ›  le4m

在Tensorflow中访问张量中的条件索引

  •  0
  • le4m  · 技术社区  · 7 年前

    假设我有张量 X x 尺寸为K。很容易获得所有样品的第K个元素: X[1:batch_size,k] . 但是假设我需要访问x的第k个元素 k_list = [1, 2, ..., 2] ,我所知道的唯一一个访问x的第k个元素的方法是

    out=[X[i,k_list[i]] for all i in range(len(k_list))]
    

    问题是这会让我的代码变慢。我们能优化这段代码吗?

    注*:事实上 k_list 作为占位符。大小 np.shape(X)=(batch_size,K) , np.shape(k_list)=(batch_size,) , np.maximum(k_list)=K-1, np.minimum(k_list)=0 np.shape(out)=(batch_size,1)

    1 回复  |  直到 7 年前
        1
  •  1
  •   DomJack    7 年前

    如果我正确理解你的问题,你在寻找 gather_nd

    i0 = tf.range(batch_size, dtype=tf.int32)
    indices = tf.stack((i0, k_list), axis=1)
    out = tf.gather_nd(X, indices)
    
    推荐文章