
How do I load a checkpoint from a model trained and saved with nn.DataParallel into a model that does not use nn.DataParallel?

  • Fatemeh  · Tech Community  · 1 year ago

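    How do I load a checkpoint from a model that was trained and saved with nn.DataParallel into a model that does not use nn.DataParallel? I tried removing the "module." prefix from the state_dict keys, but now I get a different error. Here is the link to the ResNet-50 checkpoints.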

    from collections import OrderedDict

    import torch
    from torchvision.models import resnet50

    # Load the model
    model = resnet50()
    checkpoint_path = 'C:/res50-debiased.pth.tar'
    checkpoint = torch.load(checkpoint_path)

    state_dict = checkpoint['state_dict']

    # Build a new OrderedDict whose keys do not start with `module.`
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Strip the prefix only when it is actually present
        name = k[len('module.'):] if k.startswith('module.') else k
        new_state_dict[name] = v
    # Load the parameters
    model.load_state_dict(new_state_dict)
    
    

    This produces an error:

    RuntimeError: Error(s) in loading state_dict for ResNet:
        Unexpected key(s) in state_dict: "bn1.aux_bn.weight", "bn1.aux_bn.bias", "bn1.aux_bn.running_mean", "bn1.aux_bn.running_var", "bn1.aux_bn.num_batches_tracked", "layer1.0.bn1.aux_bn.weight", "layer1.0.bn1.aux_bn.bias", "layer1.0.bn1.aux_bn.running_mean", "layer1.0.bn1.aux_bn.running_var", "layer1.0.bn1.aux_bn.num_batches_tracked", "layer1.0.bn2.aux_bn.weight", "layer1.0.bn2.aux_bn.bias", "layer1.0.bn2.aux_bn.running_mean", "layer1.0.bn2.aux_bn.running_var", "layer1.0.bn2.aux_bn.num_batches_tracked", "layer1.0.bn3.aux_bn.weight", "layer1.0.bn3.aux_bn.bias", "layer1.0.bn3.aux_bn.running_mean", "layer1.0.bn3.aux_bn.running_var", "layer1.0.bn3.aux_bn.num_batches_tracked", "layer1.0.downsample.1.aux_bn.weight", "layer1.0.downsample.1.aux_bn.bias", "layer1.0.downsample.1.aux_bn.running_mean", "layer1.0.downsample.1.aux_bn.running_var", "layer1.0.downsample.1.aux_bn.num_batches_tracked", "layer1.1.bn1.aux_bn.weight", "layer1.1.bn1.aux_bn.bias", "layer1.1.bn1.aux_bn.running_mean", "layer1.1.bn1.aux_bn.running_var", "layer1.1.bn1.aux_bn.num_batches_tracked", "layer1.1.bn2.aux_bn.weight", "layer1.1.bn2.aux_bn.bias", "layer1.1.bn2.aux_bn.running_mean", "layer1.1.bn2.aux_bn.running_var", "layer1.1.bn2.aux_bn.num_batches_tracked", "layer1.1.bn3.aux_bn.weight", "layer1.1.bn3.aux_bn.bias", 
    
    

    Loading the checkpoint directly, without renaming the keys, like this:

    # Load the model
    model = resnet50()
    checkpoint_path = 'C:/res50-debiased.pth.tar'
    checkpoint = torch.load(checkpoint_path)
    
    state_dict = checkpoint['state_dict']
    
    model.load_state_dict(state_dict)
    

    gives the following error, with every checkpoint key reported as unexpected because of the "module." prefix:

    RuntimeError: Error(s) in loading state_dict for ResNet:
        Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.conv2.weight", "layer2.0.bn2.weight", ...
    
    Unexpected key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.bn1.num_batches_tracked", "module.bn1.aux_bn.weight", "module.bn1.aux_bn.bias", "module.bn1.aux_bn.running_mean", "module.bn1.aux_bn.running_var", "module.bn1.aux_bn.num_batches_tracked", "module.layer1.0.conv1.weight", "module.layer1.0.bn1.weight", "module.layer1.0.bn1.bias", "module.layer1.0.bn1.running_mean", "module.layer1.0.bn1.running_var", "module.layer1.0.bn1.num_batches_tracked", "module.layer1.0.bn1.aux_bn.weight", "module.layer1.0.bn1.aux_bn.bias", "module.layer1.0.bn1.aux_bn.running_mean", "module.layer1.0.bn1.aux_bn.running_var", "module.layer1.0.bn1.aux_bn.num_batches_tracked", "module.layer1.0.conv2.weight", "module.layer1.0.bn2.weight", "module.layer1.0.bn2.bias", "module.layer1.0.bn2.running_mean", "module.layer1.0.bn2.running_var", "module.layer1.0.bn2.num_batches_tracked", "module.layer1.0.bn2.aux_bn.weight", "module.layer1.0.bn2.aux_bn.bias", "module.layer1.0.bn2.aux_bn.running_mean", "module.layer1.0.bn2.aux_bn.running_var", "module.layer1.0.bn2.aux_bn.num_batches_tracked", "module.layer1.0.conv3.weight", "module.layer1.0.bn3.weight", "module.layer1.0.bn3.bias", "module.layer1.0.bn3.running_mean", "module.layer1.0.bn3.running_var", "module.layer1.0.bn3.num_batches_tracked", "module.layer1.0.bn3.aux_bn.weight", "module.layer1.0.bn3.aux_bn.bias", "module.layer1.0.bn3.aux_bn.running_mean", "module.layer1.0.bn3.aux_bn.running_var", "module.layer1.0.bn3.aux_bn.num_batches_tracked", "module.layer1.0.downsample.0.weight", "module.layer1.0.downsample.1.weight", "module.layer1.0.downsample.1.bias", "module.layer1.0.downsample.1.running_mean", "module.layer1.0.downsample.1.running_var", "module.
    
    

    Thank you very much.

    1 reply  |  as of 1 year ago
  •   Ivan Jeffrey Zhao    1 year ago

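    You did the right thing by removing the "module." prefix. The remaining problem is that the resnet50 model behind this checkpoint was initialized with a custom normalization layer, MixBatchNorm2d, defined in aux_bn.py. You can see the ResNet being initialized here.
    That layer holds an extra auxiliary batch norm, which is what produces keys of the form "bn*.aux_bn". A plain torchvision resnet50() has no such submodule, so those keys are reported as unexpected.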
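
    To see at a glance which keys the plain model does not know about, the two key sets can be compared directly. The following is only a minimal sketch: `diff_state_dict_keys` is a hypothetical helper name, and the key lists are abbreviated from the error messages in the question rather than read from a real checkpoint.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, mirroring how load_state_dict reports them."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)      # the model expects these, the checkpoint lacks them
    unexpected = sorted(checkpoint_keys - model_keys)   # the checkpoint has these, the model does not
    return missing, unexpected


# Abbreviated key lists taken from the error messages in the question.
plain_resnet_keys = ["conv1.weight", "bn1.weight", "bn1.bias"]
stripped_checkpoint_keys = [
    "conv1.weight",
    "bn1.weight",
    "bn1.bias",
    "bn1.aux_bn.weight",
    "bn1.aux_bn.bias",
]

missing, unexpected = diff_state_dict_keys(plain_resnet_keys, stripped_checkpoint_keys)
print(missing)     # []
print(unexpected)  # ['bn1.aux_bn.bias', 'bn1.aux_bn.weight']
```

    With real objects, the same helper could be called with `model.state_dict().keys()` and the keys of the renamed checkpoint dictionary. If every unexpected key contains ".aux_bn.", the mismatch comes from the normalization layer and not from the prefix.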

    Your code should run once the model is initialized with the matching normalization layer:

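    import torch
    from torchvision.models import resnet50
    from aux_bn import MixBatchNorm2d  # assumes aux_bn.py from the checkpoint's repository is importable

    checkpoint = torch.load(checkpoint_path)
    # Drop the leading "module." (7 characters) that nn.DataParallel adds to every key
    state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}

    model = resnet50(num_classes=1_000, norm_layer=MixBatchNorm2d)
    model.load_state_dict(state_dict)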
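
    One caveat: `k[7:]` assumes that every key really starts with "module."; on a checkpoint saved without nn.DataParallel it would silently cut off the first seven characters of each name. A slightly more defensive variant could look like the sketch below. `strip_prefix` is a hypothetical helper name, and the integer values are stand-ins for the tensors a real state_dict would hold.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from the keys that start with it."""
    stripped = OrderedDict()
    for key, value in state_dict.items():
        # Keys without the prefix are copied unchanged
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped


# Stand-in values; a real state_dict maps these names to tensors.
example = OrderedDict(
    [
        ("module.conv1.weight", 1),
        ("module.bn1.aux_bn.weight", 2),
        ("fc.bias", 3),
    ]
)
cleaned = strip_prefix(example)
print(list(cleaned))  # ['conv1.weight', 'bn1.aux_bn.weight', 'fc.bias']
```

    Recent PyTorch releases also appear to ship a similar in-place helper, `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`; it is worth checking whether your installed version provides it before relying on it.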