代码之家  ›  专栏  ›  技术社区  ›  gasoon

如何用pytorch只更新网络中的一些特定张量?

  •  3
  • gasoon  · 技术社区  · 7 年前


    从第11代开始,我想改变,更新整个模型。
    我怎样才能达到目标?

    2 回复  |  直到 7 年前
        1
  •  6
  •   Shai    7 年前

    您可以为每个参数组设置学习速率(和一些其他元参数)。您只需要根据需要对参数进行分组。
    例如,为conv层设置不同的学习速率:

    import torch
    import itertools
    from torch import nn
    
    conv_params = itertools.chain.from_iterable([m.parameters() for m in model.children()
                                                 if isinstance(m, nn.Conv2d)])
    other_params = itertools.chain.from_iterable([m.parameters() for m in model.children()
                                                  if not isinstance(m, nn.Conv2d)]) 
    optimizer = torch.optim.SGD([{'params': other_params},
                                 {'params': conv_params, 'lr': 0}],  # set init lr to 0
                                lr=lr_for_model)
    

    您可以稍后访问优化器 param_groups

    看到了吗 per-parameter options 更多信息。

        2
  •  0
  •   weiyixie    7 年前

    非常简单,因为PYTORCH可以动态地重新创建计算图形。

    for p in resnet.parameters():
        p.requires_grad = False # this will freeze the module from training suppose that resnet is one of your module
    

    如果有多个模块,只需在其上循环。10点以后,你只要打个电话

    for p in network.parameters():
        p.requires_grad = True # suppose your whole network is the 'network' module
    
    推荐文章