代码之家  ›  专栏  ›  技术社区  ›  Daniel

重写与更改pytorch权重

  •  0
  • Daniel  · 技术社区  · 3 年前

    我试图理解为什么我不能直接覆盖火炬层的权重。 考虑以下示例:

    import torch
    from torch import nn
    
    net = nn.Linear(3, 1)
    weights = torch.zeros(1,3)
    
    # Overwriting does not work
    net.state_dict()["weight"] = weights  # nothing happens
    print(f"{net.state_dict()['weight']=}")
    
    # But mutating does work
    net.state_dict()["weight"][0] = weights  # indexing works
    print(f"{net.state_dict()['weight']=}")
    
    #########
    # output
    : net.state_dict()['weight']=tensor([[ 0.5464, -0.4110, -0.1063]])
    : net.state_dict()['weight']=tensor([[0., 0., 0.]])
    

    我很困惑 state_dict()["weight"] 只是一个火炬张量,所以我觉得我错过了一些非常明显的东西。

    1 回复  |  直到 3 年前
        1
  •  1
  •   ihdv    3 年前

    这是因为 net.state_dict() 首先创建 collections.OrderedDict 对象,然后将该模块的权重张量存储到其中,并返回dict:

    state_dict = net.state_dict()
    print(type(state_dict))    # <class 'collections.OrderedDict'>
    

    当你“覆盖”时(事实上这不是覆盖;它是 分配 在python中)这个有序的dict中,将int 0重新分配给键 'weights' 该张量中的数据没有被修改,只是没有被有序dict引用。

    当您检查张量是否被修改时:

    print(f"{net.state_dict()['weight']}")
    

    创建了一个与您修改的dict不同的新有序dict,因此您可以看到未更改的张量。

    但是,当您像这样使用索引时:

    net.state_dict()["weight"][0] = weights  # indexing works
    

    那么它就不再分配给有序dict了。相反 __setitem__ 张量的方法被调用,它允许您访问和修改底层内存。其他张量API,如 copy_ 也可以获得期望的结果。

    对的差异的清晰解释 a = b a[:] = b 什么时候 a 是张量/数组。可以在这里找到: https://stackoverflow.com/a/68978622/11790637

        2
  •  0
  •   John Stud    3 年前

    我现在还没有安装torch,但从我保存的一些代码中尝试这样的东西。我相信你需要做深度复制,就像这样

    def zero_injection(initial_weights, trained_weights, mask):
        ''' zeros all weights and then injects in masked selection '''
        # copy the weights
        initial_weights_copy = copy.deepcopy(initial_weights.state_dict())
        trained_weights_copy = copy.deepcopy(trained_weights.state_dict())
    
        # set all the values to zero
        for key, value in initial_weights_copy.items():
            initial_weights_copy[key][initial_weights_copy[key] < 0] = 0
            initial_weights_copy[key][initial_weights_copy[key] > 0] = 0
    
        state_dict = {}
        # for each key
        for key, value in initial_weights_copy.items():
            # add the key
            state_dict[key] = []
            # if False, replace initial value with trained value
            state_dict[key] = initial_weights_copy[key].cuda().where(mask[key].cuda(), trained_weights_copy[key].cuda())
    
        return state_dict