
PyTorch equivalent of index_add_ that takes the maximum instead

  • Kiv · 7 years ago

    In PyTorch, a tensor's index_add_ method sums values into the tensor at the positions given by an index tensor:

    idx = torch.LongTensor([0,0,0,0,1,1])
    child = torch.FloatTensor([1, 3, 5, 10, 8, 1])
    parent = torch.FloatTensor([0, 0])
    parent.index_add_(0, idx, child)
    

    The first four child values are summed into parent[0] and the last two into parent[1], so the result is tensor([ 19., 9.]).
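    For reference, index_add_ along dim 0 behaves like the explicit loop below. This is only a minimal sketch of the semantics (the helper name index_add_loop is hypothetical), not how PyTorch implements the operation internally:

```python
import torch

def index_add_loop(parent, idx, child):
    # Hypothetical reference helper: for every position j,
    # add child[j] into the slot of parent selected by idx[j].
    out = parent.clone()
    for j in range(idx.shape[0]):
        out[int(idx[j])] += child[j]
    return out

idx = torch.tensor([0, 0, 0, 0, 1, 1])
child = torch.tensor([1., 3., 5., 10., 8., 1.])
parent = torch.zeros(2)

looped = index_add_loop(parent, idx, child)
builtin = parent.clone().index_add_(0, idx, child)
print(looped, builtin)  # both should hold the sums 19 and 9
```

    An index_max_ would follow the same pattern with out[i] = max(out[i], child[j]) in place of the addition.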

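    However, I need an index_max_ operation instead, and it does not exist in the API. Is there an efficient way to do this (without looping or allocating more memory)? A (bad) loop-based solution is: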

    for i in range(max(idx)+1):
        parent[i] = torch.max(child[idx == i])
    

    This produces tensor([ 10., 8.]), but it is very slow.
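    In newer PyTorch releases (assuming PyTorch 1.12 or later, where Tensor.index_reduce_ was added as a beta feature), this index_max_ behaviour is available as a built-in reduction with reduce="amax". A minimal sketch on the example data from the question:

```python
import torch

idx = torch.tensor([0, 0, 0, 0, 1, 1])
child = torch.tensor([1., 3., 5., 10., 8., 1.])
parent = torch.zeros(2)

# Assumes PyTorch >= 1.12. include_self=False excludes the initial zeros of
# `parent` from the reduction, so each slot becomes the max over exactly the
# child values whose index points at it.
parent.index_reduce_(0, idx, child, reduce="amax", include_self=False)
print(parent)
```

    Because include_self=False ignores the initial contents of the indexed slots, the result is correct even when every value in a partition is negative.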

  • benjaminplanche · 7 years ago

    A vectorized solution using an index mask:

    import torch
    
    def index_max(child, idx, num_partitions):
        # Building a num_partitions x num_samples boolean matrix `idx_tiled`,
        # i.e. idx_tiled[i,j] is True if idx[j] == i, else False:
        partition_idx = torch.arange(num_partitions, dtype=torch.long).view(-1, 1)
        idx_tiled = idx.view(1, -1) == partition_idx  # broadcasts to (num_partitions, num_samples)
    
        # Entries outside partition i are set to -inf instead of being multiplied by 0,
        # so partitions whose values are all negative still return their true max.
        # (A partition with no entries at all comes out as -inf.)
        parent = child.view(1, -1).expand(num_partitions, -1)
        parent = parent.masked_fill(~idx_tiled, float('-inf'))
        parent, _ = torch.max(parent, dim=1)
        return parent
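    The mask-based approach materialises a num_partitions x num_samples matrix, which the question hoped to avoid. Assuming PyTorch 1.12 or later, Tensor.scatter_reduce with reduce="amax" computes the same per-partition maxima without that intermediate. A sketch (the wrapper name index_max_scatter is hypothetical):

```python
import torch

def index_max_scatter(child, idx, num_partitions):
    # Hypothetical wrapper; assumes PyTorch >= 1.12 for scatter_reduce.
    # Start from -inf so partitions with only negative values
    # still report their true maximum.
    out = torch.full((num_partitions,), float("-inf"),
                     dtype=child.dtype, device=child.device)
    # For 1-D tensors, each child[j] is reduced into out[idx[j]] with max.
    return out.scatter_reduce(0, idx, child, reduce="amax")

idx = torch.tensor([0, 0, 0, 0, 1, 1])
child = torch.tensor([1., 3., 5., 10., 8., 1.])
result = index_max_scatter(child, idx, 2)

negative = torch.tensor([-4., -2., -7., -1., -5., -3.])
negative_result = index_max_scatter(negative, idx, 2)
print(result, negative_result)
```

    Partitions that receive no values are left at -inf, which makes them easy to detect afterwards.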
    

    Benchmarking:

    import timeit
    
    setup = '''
    import torch
    
    def index_max_v0(child, idx, num_partitions):
        parent = torch.zeros(num_partitions)
        for i in range(max(idx) + 1):
            parent[i] = torch.max(child[idx == i])
        return parent
    
    def index_max(child, idx, num_partitions):
    
        # Building a num_partitions x num_samples matrix `idx_tiled`,
        # i.e. idx_tiled[i,j] == 1 if idx[j] == i, else 0
        partition_idx = torch.arange(num_partitions, dtype=torch.long)
        partition_idx = partition_idx.view(-1, 1).expand(num_partitions, idx.shape[0])
        idx_tiled = idx.view(1, -1).repeat(num_partitions, 1)
        idx_tiled = (idx_tiled == partition_idx).float()
    
        parent = idx_tiled * child
        parent, _ = torch.max(parent, dim=1)
        return parent
    
    idx = torch.LongTensor([0,0,0,0,1,1])
    child = torch.FloatTensor([1, 3, 5, 10, 8, 1])
    num_partitions = torch.unique(idx).shape[0]
    
    '''
    print(min(timeit.Timer('index_max_v0(child, idx, num_partitions)', setup=setup).repeat(5, 1000)))
    # > 0.05308796599274501
    print(min(timeit.Timer('index_max(child, idx, num_partitions)', setup=setup).repeat(5, 1000)))
    # > 0.024736385996220633