代码之家  ›  专栏  ›  技术社区  ›  Farshid Rayhan

pytorch resnet在哪里添加值?

  •  0
  • Farshid Rayhan  · 技术社区  · 6 年前

    我正在研究resnet,我发现了一个使用加号进行跳过连接的实现。像下面这样

    Class Net(nn.Module):
        def __init__(self):
            super(Net, self).__int_() 
                self.conv = nn.Conv2d(128,128)
    
        def forward(self, x):
            out = self.conv(x) // line 1 
            x = out + x    // skip connection  // line 2
    

    现在我已经调试并打印了第1行前后的值。输出如下:

    在第1行之后
    X=[1128,32,32]
    输出=[1128,32,32]

    第2行之后
    x=[1128,32,32]//静止

    参考链接: https://github.com/kuangliu/pytorch-cifar/blob/bf78d3b8b358c4be7a25f9f9438c842d837801fd/models/resnet.py#L62

    我的问题是它在哪里增加了价值??我是说之后

    X=输出+X

    操作,增值在哪里?

    张量格式是[批,通道,高度,宽度]。

    1 回复  |  直到 6 年前
        1
  •  1
  •   benjaminplanche    6 年前

    3x3 [3, 3] 1x128x32x32

    import torch
    
    out = torch.ones((3, 3))
    x = torch.eye(3, 3)
    res = out + x
    
    print(out.shape)
    # torch.Size([3, 3])
    print(out)
    # tensor([[ 1.,  1.,  1.],
    #         [ 1.,  1.,  1.],
    #         [ 1.,  1.,  1.]])
    print(x.shape)
    # torch.Size([3, 3])
    print(x)
    # tensor([[ 1.,  0.,  0.],
    #         [ 0.,  1.,  0.],
    #         [ 0.,  0.,  1.]])
    print(res.shape)
    # torch.Size([3, 3])
    print(res)
    # tensor([[ 2.,  1.,  1.],
    #         [ 1.,  2.,  1.],
    #         [ 1.,  1.,  2.]])