代码之家  ›  专栏  ›  技术社区  ›  DsCpp

收集比输入数据更高维度的索引?

  •  0
  • DsCpp  · 技术社区  · 7 年前

    阅读 Dynamic Graph CNN for Learning on Point Clouds 代码,我看到这段代码:

      idx_ = tf.range(batch_size) * num_points
      idx_ = tf.reshape(idx_, [batch_size, 1, 1]) 
    
      point_cloud_flat = tf.reshape(point_cloud, [-1, num_dims])
      point_cloud_neighbors = tf.gather(point_cloud_flat, nn_idx+idx_)  <--- what happens here?
      point_cloud_central = tf.expand_dims(point_cloud_central, axis=-2)
    

    调试线路时,我确保调光器

    point_cloud_flat:(32768,3) nn_idx:(32,1024,20), idx_:(32,1,1) 
    // indices are (32,1024,20) after broadcasting
    

    阅读 tf.gather doc 我无法理解在维度高于输入维度的情况下函数的作用。

    1 回复  |  直到 7 年前
        1
  •  1
  •   LI Xuhong    7 年前

    numpy中的等效函数是 np.take ,一个简单的例子:

    import numpy as np
    
    params = np.array([4, 3, 5, 7, 6, 8])
    
    # Scalar indices; (output is rank(params) - 1), i.e. 0 here.
    indices = 0
    print(params[indices])
    
    # Vector indices; (output is rank(params)), i.e. 1 here.
    indices = [0, 1, 4]
    print(params[indices])  # [4 3 6]
    
    # Vector indices; (output is rank(params)), i.e. 1 here.
    indices = [2, 3, 4]
    print(params[indices])  # [5 7 6]
    
    # Higher rank indices; (output is rank(params) + rank(indices) - 1), i.e. 2 here
    indices = np.array([[0, 1, 4], [2, 3, 4]])
    print(params[indices])  # equivalent to np.take(params, indices, axis=0)
    # [[4 3 6]
    # [5 7 6]]
    

    在你的情况下, indices 高于 params ,所以输出是等级( 帕拉姆 +秩(+秩) 指数 )-1(即2+3-1=4,即(32,1024,20,3))。这个 - 1 是因为 tf.gather(axis=0) axis 此时必须是0级(所以是一个标量)。所以 指数 获取第一维度的元素( axis=0 )以“花哨”的索引方式。

    编辑 :

    简而言之,在你的情况下,(如果我没有误解代码)

    • point_cloud IS(32、1024、3),32批1024分,其中3批 协调。
    • nn_idx IS(32,1024,20),20个邻居的指数 32批1024分。索引用于索引 点云 .
    • nn_idx+idx_ (32,1024,20),20个邻居的指数 32批1024分。索引用于索引 point_cloud_flat .
    • point_cloud_neighbors 最后是(321024, 20,3),同 nnIdx+ixxi 除了那个 点云邻居 他们的3个坐标是 nn_idx+idx_ 只是他们的指数。
    推荐文章