代码之家  ›  专栏  ›  技术社区  ›  sanjeev mk

通过索引从Pytorch或Numpy 2D数组中快速删除多行的方法

  •  2
  • sanjeev mk  · 技术社区  · 1 年前

    我有一个numpy数组(相当于Pytorch张量)的形状 Nx3 。我还有一个与行相对应的索引列表,我想从这个张量中删除这些索引。此索引列表称为 remove_ixs . N 非常大,大约有500万行,并且 删除_ixs 长50k。我现在的做法如下:

    mask = [i not in remove_ixs for i in range(my_array.shape[0])]
    new_array = my_array[mask,:]
    

    但第一条线只是没有终止,需要很长时间。以上是numpy代码。一个等效的Pytorch代码也适用于我。

    有没有更快的方法可以用numpy或pytorch做到这一点?

    2 回复  |  直到 1 年前
        1
  •  2
  •   jared    1 年前

    您可以创建一个初始 mask (布尔)数组 True 对于要删除的元素,然后将其反转以给出 面具 的元素。

    remove_mask = np.zeros(my_array.shape[0], dtype=bool)
    remove_mask[remove_ixs] = True
    mask = ~remove_mask
        
    new_array = my_array[mask, :]
    

    或者全部启动 真的 相反:

    mask = np.ones(my_array.shape[0], dtype=bool)
    mask[remove_ixs] = False
        
    new_array = my_array[mask, :]
    

    出于某种原因,对于较小的阵列,第一个版本的速度更快。

        2
  •  0
  •   user24714692    1 年前

    您可以使用 np.delete() :

    import numpy as np
    
    A = np.random.rand(5000000, 3)
    remove_ixs = np.random.choice(5000000, 50000, replace=False)
    B = np.delete(A, remove_ixs, axis=0)
    
    print(len(B))
    
    

    打印

    4950000