pytorchでPSPNet(その1 ネットワークを作る[DecodeとAuxiliary Loss])

書いてる理由

  • pytorchを基礎からもう一回

参考

pytorchによる発展ディープラーニング

https://arxiv.org/pdf/1612.01105.pdf

詳細

pytorch_work/network.py at master · ys201810/pytorch_work · GitHub

PSPNetでセマンティックセグメンテーションする。
昨日はPyramid Poolingの箇所だったので、今日はDecodeとAuxLOSSの箇所。

f:id:raishi12:20200312221551p:plain

PSPNetは、上の画像の通り、inputの画像(a)をFeature Map(b)で特徴抽出して、Pyramid Pooling Moudle(c)で4つの異なるサイズのFeature Mapを元のFeature MAPの大きさにUpsamplingした上でconcatしてから予測(d)する。
この(d)の箇所

class DecodePSPFeature(nn.Module):
    def __init__(self, height, width, n_classes):
        super(DecodePSPFeature, self).__init__()

        self.height = height
        self.width = width

        self.cbr = Conv2DBatchNormRelu(in_channels=4096, out_channels=512, kernel_size=3, stride=1, padding=1,
                                       dilation=1, bias=False)
        self.dropout = nn.Dropout2d(p=0.1)
        self.classification = nn.Conv2d(in_channels=512, out_channels=n_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.cbr(x)
        x = self.dropout(x)
        x = self.classification(x)
        output = F.interpolate(x, size=(self.height, self.width), mode='bilinear', align_corners=True)

        return output


class AuxiliaryPSPlayers(nn.Module):
    def __init__(self, in_channels, height, width, n_classes):
        super(AuxiliaryPSPlayers, self).__init__()

        self.height = height
        self.width = width

        self.cbr = Conv2DBatchNormRelu(in_channels=in_channels, out_channels=256, kernel_size=3, stride=1, padding=1,
                                       dilation=1, bias=False)
        self.dropout = nn.Dropout2d(p=0.1)
        # pointwise convolutionでout_class分の出力に変換。
        self.classification = nn.Conv2d(in_channels=256, out_channels=n_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.cbr(x)
        x = self.dropout(x)
        x = self.classification(x)
        output = F.interpolate(x, size=(self.height, self.width), mode='bilinear', align_corners=True)

        return output

この二つは似ているのだけど、Decodeの方は最後まで通した上でUpsampling。
AuxはresidualBlock1を通った時点でUpsampling。
Decodeは最後まで通っているのでchannel数が4096がinでAuxはresidualBlock1の出力時点の256がinに指定されている。
interpolateでUpsamplingする前に、両方nn.Conv2d(in_channels=, out_channels=n_classes, kernel_size=1, stride=1, padding=0)を通す。
kernel_size=1のConvolutionは、縦横を全捜査してout_channelsがクラス数分のpoint wise convolutionになっている。
これと正解を比べてロスを作っていく。
今日はここまで。次はロス関数と学習部分。