I was just trying to figure out the same thing and found your question while searching on Google.
After some digging:
- register_forward_hook was added 7 years ago in this PR
- register_module_forward_hook was added 3 years ago in this PR
The former apparently has to be registered on each module separately, while the latter is a global hook that you set up once and it fires for every module:
test_fwd = nn.modules.module.register_module_forward_hook(lambda *args: fw_hook(1, *args))
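For contrast, a minimal sketch of the per-module approach (fw_hook and model stand in for whatever hook and network you're using, as in the line above): every submodule needs its own registration, and each returned handle has to be removed individually.

handles = [m.register_forward_hook(lambda *args: fw_hook(1, *args))
           for m in model.modules()]
# ... run forward passes ...
for handle in handles:
    handle.remove()  # per-module hooks are cleaned up one handle at a time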
Looking at the blame for register_module_forward_hook turns up this relevant issue from 3 months ago with more details.
It sounds like the latter is a better fit for your situation. In particular, take a look at the comments on the recent commit, since it makes the hook compatible with context managers.
For example, you can use it to compute per-example activation norms at every layer with a context manager like this:
@contextmanager
def module_hook(hook: Callable):
    handle = nn.modules.module.register_module_forward_hook(hook, always_call=True)
    yield
    handle.remove()

def compute_norms(layer: nn.Module, inputs: Tuple[torch.Tensor], _output: torch.Tensor):
    A = inputs[0].detach()
    layer.norms2 = (A * A).sum(dim=1)

with module_hook(compute_norms):
    outputs = model(data)

print("layer", "norms squared")
for name, layer in model.named_modules():
    if not name:
        continue
    print(f"{name:20s}: {layer.norms2.cpu().numpy()}")
Full code from this colab:
from contextlib import contextmanager
from typing import Callable, Tuple

import torch
import torch.nn as nn
import numpy as np

device = 'cuda' if torch.cuda.is_available() else 'cpu'
data = torch.tensor([[1., 0.], [1., 1.]]).to(device)
bs = data.shape[0]  # batch size

def simple_model(d, num_layers):
    """Creates a simple linear network with each layer initialized to 2*identity."""
    layers = []
    for i in range(num_layers):
        layer = nn.Linear(d, d, bias=False)
        layer.weight.data.copy_(2 * torch.eye(d))
        layers.append(layer)
    return torch.nn.Sequential(*layers)

def compute_norms(layer: nn.Module, inputs: Tuple[torch.Tensor], _output: torch.Tensor):
    assert len(inputs) == 1, "multi-input layer??"
    A = inputs[0].detach()
    layer.norms2 = (A * A).sum(dim=1)  # per-example squared norm of the layer's input

model = simple_model(2, 3).to(device)

@contextmanager
def module_hook(hook: Callable):
    handle = nn.modules.module.register_module_forward_hook(hook, always_call=True)
    yield
    handle.remove()

with module_hook(compute_norms):
    outputs = model(data)

# Each layer doubles its input, so the squared input norms quadruple from layer to layer.
np.testing.assert_allclose(model[0].norms2.cpu(), [1, 2])
np.testing.assert_allclose(model[1].norms2.cpu(), [4, 8])
np.testing.assert_allclose(model[2].norms2.cpu(), [16, 32])

print(f"{'layer':20s}: {'norms squared'}")
for name, layer in model.named_modules():
    if not name:  # skip the root nn.Sequential, whose name is empty
        continue
    print(f"{name:20s}: {layer.norms2.cpu().numpy()}")

assert not torch.nn.modules.module._global_forward_hooks, "Some hooks remain"
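One caveat with the module_hook context manager above: if the forward pass raises, handle.remove() is never reached and the global hook stays registered. A minimal sketch of a more defensive variant, using only the same PyTorch API plus try/finally:

from contextlib import contextmanager
from typing import Callable

import torch.nn as nn

@contextmanager
def module_hook(hook: Callable):
    handle = nn.modules.module.register_module_forward_hook(hook, always_call=True)
    try:
        yield handle
    finally:
        handle.remove()  # runs even if the forward pass raised

With that change, the final assert on _global_forward_hooks holds even when an exception interrupts the with block.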