代码之家  ›  专栏  ›  技术社区  ›  Rocketq

如何在pytorch的Unet中使用PNASNet5作为编码器

  •  0
  • Rocketq  · 技术社区  · 7 年前

    我想使用 PNASNet5Large 作为我的 UNet 的编码器。下面是我的尝试，它对 ResNet 编码器可以工作，但对 PNASNet5Large 不行：

    class UNetResNet(nn.Module):
        """U-Net whose encoder is a torchvision ResNet (depth 34/101/152) or,
        experimentally, a PNASNet5Large (encoder_depth == 777).

        NOTE(review): PNASNet5Large exposes no ``conv1``/``bn1``/``layer1..4``
        attributes, so the attribute wiring below only works for the ResNet
        branches; the 777 branch will fail with AttributeError at ``self.conv1``.
        """

        def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                     pretrained=False, is_deconv=False):
            super().__init__()
            self.num_classes = num_classes
            self.dropout_2d = dropout_2d

            if encoder_depth == 34:
                self.encoder = torchvision.models.resnet34(pretrained=pretrained)
                bottom_channel_nr = 512
            elif encoder_depth == 101:
                self.encoder = torchvision.models.resnet101(pretrained=pretrained)
                bottom_channel_nr = 2048
            elif encoder_depth == 152:  # known-working configuration
                self.encoder = torchvision.models.resnet152(pretrained=pretrained)
                bottom_channel_nr = 2048
            elif encoder_depth == 777:  # ad-hoc code path for the PNASNet encoder
                self.encoder = PNASNet5Large()
                # 4320 = channel count of the last PNASNet cell (cell_11).
                bottom_channel_nr = 4320
            else:
                # Fail fast instead of hitting a confusing NameError on
                # bottom_channel_nr further down.
                raise ValueError('unsupported encoder_depth: %r' % (encoder_depth,))

            self.pool = nn.MaxPool2d(2, 2)
            self.relu = nn.ReLU(inplace=True)
            # Stem: reuse the encoder's first conv/bn/relu, then downsample.
            self.conv1 = nn.Sequential(self.encoder.conv1,
                                       self.encoder.bn1,
                                       self.encoder.relu,
                                       self.pool)

            # ResNet stages reused as encoder blocks (PNASNet5Large has no
            # such layers -- see class note above).
            self.conv2 = self.encoder.layer1
            self.conv3 = self.encoder.layer2
            self.conv4 = self.encoder.layer3
            self.conv5 = self.encoder.layer4

            self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)

            # Decoder: each block consumes the previous decoder output
            # concatenated with the matching encoder feature map.
            self.dec5 = DecoderBlock(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
            self.dec4 = DecoderBlock(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
            self.dec3 = DecoderBlock(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
            self.dec2 = DecoderBlock(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                     is_deconv)
            self.dec1 = DecoderBlock(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
            self.dec0 = ConvRelu(num_filters, num_filters)
            self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

        def forward(self, x):
            """Encode with the backbone stages, decode with skip connections."""
            conv1 = self.conv1(x)
            conv2 = self.conv2(conv1)
            conv3 = self.conv3(conv2)
            conv4 = self.conv4(conv3)
            conv5 = self.conv5(conv4)
            center = self.center(conv5)
            dec5 = self.dec5(torch.cat([center, conv5], 1))
            dec4 = self.dec4(torch.cat([dec5, conv4], 1))
            dec3 = self.dec3(torch.cat([dec4, conv3], 1))
            dec2 = self.dec2(torch.cat([dec3, conv2], 1))
            dec1 = self.dec1(dec2)
            dec0 = self.dec0(dec1)
            return self.final(F.dropout2d(dec0, p=self.dropout_2d))
    

    1) 如何获取pnasnet有多少个底层通道。结果是:

    ...
     self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864,
                                in_channels_right=4320, out_channels_right=864)
            self.relu = nn.ReLU()
            self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0)
            self.dropout = nn.Dropout(0.5)
            self.last_linear = nn.Linear(4320, num_classes)
    

    4320 in_channels_left out_channels_left -对我来说有些新东西

    2) Resnet有4个大的层,我在我的Unet架构中使用它们和编码器,如何从pnasnet获得类似的层

    我使用的是pytorch 3.1,这是到 Pnasnet directory

    3) AttributeError:“PNASNet5Large”对象没有属性“conv1”-因此也没有conv1

    UPD:像这样尝试smth,但失败了

    class UNetPNASNet(nn.Module):
        """Question's UPD attempt: U-Net decoder over a PNASNet5Large encoder.

        The class header was reconstructed from the machine-garbled paste and
        the method indentation repaired so ``forward`` belongs to the class.

        NOTE(review): ``encoder.avg_pool`` is an 11x11 average pool intended
        for 331x331 classification inputs; on smaller feature maps it raises
        the "Output size is too small" RuntimeError quoted below, and after
        that pooling the "skip" tensors carry almost no spatial detail --
        confirm the intended wiring before use.
        """

        def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                     pretrained=False, is_deconv=False):
            super().__init__()
            self.num_classes = num_classes
            self.dropout_2d = dropout_2d
            self.encoder = PNASNet5Large()
            bottom_channel_nr = 4320  # output channels of the last PNASNet cell

            self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)
            self.dec5 = DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
            self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
            self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
            self.dec2 = DecoderBlockV2(num_filters * 4 * 4, num_filters * 4 * 4, num_filters, is_deconv)
            self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
            self.dec0 = ConvRelu(num_filters, num_filters)
            self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

        def forward(self, x):
            """Run the PNASNet feature extractor, then decode."""
            features = self.encoder.features(x)
            relued_features = self.encoder.relu(features)
            avg_pooled_features = self.encoder.avg_pool(relued_features)
            center = self.center(avg_pooled_features)
            # Skip connections come from the encoder tensors computed above.
            dec5 = self.dec5(torch.cat([center, avg_pooled_features], 1))
            dec4 = self.dec4(torch.cat([dec5, relued_features], 1))
            dec3 = self.dec3(torch.cat([dec4, features], 1))
            dec2 = self.dec2(dec3)
            dec1 = self.dec1(dec2)
            dec0 = self.dec0(dec1)
            return self.final(F.dropout2d(dec0, p=self.dropout_2d))
    

    RuntimeError: Given input size: (4320x4x4). Calculated output size: (4320x-6x-6). Output size is too small at /opt/conda/conda-bld/pytorch_.../work/torch/lib/THCUNN/generic/SpatialAveragePooling.c:63

    1 回复  |  直到 7 年前
        1
  •  1
  •   iacolippo    7 年前

    所以你想在 UNet 中用 PNASNetLarge 取代 ResNet 作为编码器。目前你的 __init__ 中是这样写的：

    self.pool = nn.MaxPool2d(2, 2)
    self.relu = nn.ReLU(inplace=True)
    self.conv1 = nn.Sequential(self.encoder.conv1,
                               self.encoder.bn1,
                               self.encoder.relu,
                               self.pool)
    
    self.conv2 = self.encoder.layer1
    self.conv3 = self.encoder.layer2
    self.conv4 = self.encoder.layer3
    self.conv5 = self.encoder.layer4
    

    layer4 是平均池化之前的最后一个块，而你给出的通道数对应的是平均池化 之后 的大小，因此我猜你在 self.conv5 = self.encoder.layer4 之后漏掉了 self.encoder.avgpool 。torchvision.models 中 ResNet 的 forward 如下：

    def forward(self, x):
        """Standard torchvision ResNet forward pass: stem, four residual
        stages, global average pool, flatten, fully-connected head."""
        # Stem: conv -> batch norm -> relu -> max pool.
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # The four residual stages, applied in order.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)

        # Classification head: pool, flatten to (batch, -1), project.
        out = self.avgpool(out)
        return self.fc(out.view(out.size(0), -1))
    

    我想你也应该采取类似的解决办法 PNASNet5Large (使用体系结构直至平均池层)。

    1) 要知道 PNASNet5Large 底部有多少个通道，你需要查看平均池化之后输出张量的大小，例如给网络喂一个伪张量。还要注意，ResNet 通常使用 (batch_size, 3, 224, 224) 的输入大小，而 PNASNet5Large 使用 (batch_size, 3, 331, 331) 。

    m = PNASNet5Large()
    x1 = torch.randn(1, 3, 331, 331)
    m.avg_pool(m.features(x1)).size()
    torch.Size([1, 4320, 1, 1])
    

    bottom_channel_nr=4320 为了你的PNASNet。

    2) 你需要相应地修改 UNet 的 __init__ 和 forward 。如果你决定使用 PNASNet ，我建议你另写一个类：

    class UNetPNASNet(nn.Module):
        """Answer's sketch: U-Net decoder over a PNASNet5Large encoder.

        Fixes over the pasted version: ``forward`` referenced undefined
        ``conv2..conv5`` (left over from the ResNet variant, a guaranteed
        NameError); it now feeds the decoder from the PNASNet tensors it
        actually computes, and the method indentation is repaired so
        ``forward`` belongs to the class.

        NOTE(review): channel/spatial sizes of the skip concatenations are
        not validated here -- confirm against DecoderBlock's expected inputs.
        """

        def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                     pretrained=False, is_deconv=False):
            super().__init__()
            self.num_classes = num_classes
            self.dropout_2d = dropout_2d
            self.encoder = PNASNet5Large()
            bottom_channel_nr = 4320  # output channels of the last PNASNet cell
            self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)

            self.dec5 = DecoderBlock(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
            self.dec4 = DecoderBlock(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
            self.dec3 = DecoderBlock(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
            self.dec2 = DecoderBlock(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                     is_deconv)
            self.dec1 = DecoderBlock(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
            self.dec0 = ConvRelu(num_filters, num_filters)
            self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

        def forward(self, x):
            """Run the PNASNet feature extractor, then decode."""
            features = self.encoder.features(x)
            relued_features = self.encoder.relu(features)
            avg_pooled_features = self.encoder.avg_pool(relued_features)
            center = self.center(avg_pooled_features)
            # Skip connections come from the PNASNet tensors computed above
            # (the pasted answer used undefined conv2..conv5 here).
            dec5 = self.dec5(torch.cat([center, avg_pooled_features], 1))
            dec4 = self.dec4(torch.cat([dec5, relued_features], 1))
            dec3 = self.dec3(torch.cat([dec4, features], 1))
            dec2 = self.dec2(dec3)
            dec1 = self.dec1(dec2)
            dec0 = self.dec0(dec1)
            return self.final(F.dropout2d(dec0, p=self.dropout_2d))
    

    3) PNASNet5Large 没有 conv1 属性。可以通过检查它的子模块名称来确认：

    'conv1' in dict(m.named_modules())
    False