代码之家  ›  专栏  ›  技术社区  ›  bluesummers

numpy-如何将向量索引数组转换为掩码?

  •  1
  • bluesummers  · 技术社区  · 5 年前

    给出了一个 np.ndarray 命名 indices 用一个 n 每行中的行和可变长度向量,我要创建一个 n 行和 m 行,其中 已知值是否等于 指数 . 请注意,在 指数 参考每行索引,而不是全局矩阵索引。

    例如,给定:

    indices = np.array([
        [2, 0],
        [0],
        [4, 7, 1]
    ])
    
    # Expected output
    print(mask)
    [[ True False  True False False False False False]
     [ True False False False False False False False]
     [False  True False False  True False False  True]]
    

    预先知道(每行的最大长度 mask )不需要从中推断 指数

    通知 :这不同于将索引数组转换为一个掩码,其中索引引用生成的矩阵索引。

    2 回复  |  直到 5 年前
        1
  •  1
  •   Derlin    5 年前

    这里有一个变体:

    def create_mask(indices, m):
        mask = np.zeros((len(indices), m), dtype=bool)
        for i, idx in enumerate(indices):
            mask[i, idx] = True
        return mask
    

    用途:

    >>> create_mask(indices, 8)
    array([[ True, False,  True, False, False, False, False, False],
           [ True, False, False, False, False, False, False, False],
           [False,  True, False, False,  True, False, False,  True]])
    
        2
  •  2
  •   Divakar    5 年前

    这是一条路-

    def mask_from_indices(indices, ncols=None):
        # Extract column indices
        col_idx = np.concatenate(indices)
    
        # If number of cols is not given, infer it based on max column index
        if ncols is None:
            ncols = col_idx.max()+1
    
        # Length of indices, to be used as no. of rows in o/p
        n = len(indices)
    
        # Initialize o/p array
        out = np.zeros((n,ncols), dtype=bool)
    
        # Lengths of each index element that represents each group of col indices
        lens = np.array(list(map(len,indices)))
    
        # Use np.repeat to generate all row indices
        row_idx = np.repeat(np.arange(len(lens)),lens)
    
        # Finally use row, col indices to set True values
        out[row_idx,col_idx] = 1
        return out    
    

    样品运行-

    In [89]: mask_from_indices(indices)
    Out[89]: 
    array([[ True, False,  True, False, False, False, False, False],
           [ True, False, False, False, False, False, False, False],
           [False,  True, False, False,  True, False, False,  True]])