代码之家  ›  专栏  ›  技术社区  ›  Tolga Aktas

PyTorch中的register_forward_hook和register_module_forward_hook之间有什么区别?

  •  0
  • Tolga Aktas  · 技术社区  · 2 年前

    正如标题所示,我试图理解这两个函数在PyTorch中作为前向钩子的功能是如何实现的?我看到regisfter_module_forward_hook添加了一个全局状态,我假设这意味着所有前向挂钩都有一个函数。是这样吗?或者它的功能与更常用的register_forward_hook有何不同?

    我最终要写的是从给定网络的所有层计算相同的统计信息,因此用作钩子的函数在所有层中都是相同的。后者是更好的选择吗?

    我还没有尝试过使用它们,因为我正在努力找出哪一种更适合我的情况。

    0 回复  |  直到 2 年前
        1
  •  1
  •   Yaroslav Bulatov    2 年前

    我只是想弄清楚同一个问题,在谷歌上搜索时发现了你的问题。

    通过一些挖掘:

    • register_forward_hook 添加在中 this PR 7年
    • register_module_forward_hook 3年前添加于 this PR

    前者似乎需要在每个模块的基础上进行设置,而后者是一个全局挂钩,您可以为每个模块设置一次。

     test_fwd = nn.modules.module.register_module_forward_hook(lambda *args: fw_hook(1, *args))
    

    正在查看 blame 对于 register_module_forward_hook 显示此 relevant issue 3个月前的更多细节。

    听起来后者更适合您的情况。特别是,考虑最近提交的评论,因为它使其与上下文管理器兼容。

    例如,您可以使用它通过使用这样的上下文管理器来计算每个层上的每个示例的激活规范

    @contextmanager
    def module_hook(hook: Callable):
        handle = nn.modules.module.register_module_forward_hook(hook, always_call=True)
        yield
        handle.remove()
    
    def compute_norms(layer: nn.Module, inputs: Tuple[torch.Tensor], _output: torch.Tensor):
        A = inputs[0].detach()
        layer.norms2 = (A * A).sum(dim=1)
    
    with module_hook(compute_norms):
        outputs = model(data)
    
    print("layer", "norms squared")
    for name, layer in model.named_modules():
        if not name:
            continue
        print(f"{name:20s}: {layer.norms2.cpu().numpy()}")
    
    

    enter image description here

    完整代码来自 colab

    from contextlib import contextmanager
    from typing import Callable, Tuple
    
    import torch
    import torch.nn as nn
    
    import numpy as np
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    data = torch.tensor([[1., 0.], [1., 1.]]).to(device)
    bs = data.shape[0]  # batch size
    
    def simple_model(d, num_layers):
        """Creates simple linear neural network initialized to 2*identity"""
        layers = []
        for i in range(num_layers):
            layer = nn.Linear(d, d, bias=False)
            layer.weight.data.copy_(2 * torch.eye(d))
            layers.append(layer)
        return torch.nn.Sequential(*layers)
    
    norms = [torch.zeros(bs).to(device)]
    
    def compute_norms(layer: nn.Module, inputs: Tuple[torch.Tensor], _output: torch.Tensor):
        assert len(inputs) == 1, "multi-input layer??"
        A = inputs[0].detach()
        layer.norms2 = (A * A).sum(dim=1)
    
    model = simple_model(2, 3).to(device)
    
    @contextmanager
    def module_hook(hook: Callable):
        handle = nn.modules.module.register_module_forward_hook(hook, always_call=True)
        yield
        handle.remove()
    
    with module_hook(compute_norms):
        outputs = model(data)
    
    np.testing.assert_allclose(model[0].norms2.cpu(), [1, 2])
    np.testing.assert_allclose(model[1].norms2.cpu(), [4, 8])
    np.testing.assert_allclose(model[2].norms2.cpu(), [16, 32])
    
    print(f"{'layer':20s}: {'norms squared'}")
    for name, layer in model.named_modules():
        if not name:
            continue
        print(f"{name:20s}: {layer.norms2.cpu().numpy()}")
    #     print(name, layer.norms2)
    
    assert not torch.nn.modules.module._global_forward_hooks, "Some hooks remain"