書いてる理由
- pytorchを基礎からもう一回
参考
https://arxiv.org/pdf/1612.01105.pdf
詳細
pytorch_work/network.py at master · ys201810/pytorch_work · GitHub
PSPNetでセマンティックセグメンテーションする。
昨日はPyramid Poolingの箇所だったので、今日はDecodeとAuxLOSSの箇所。
PSPNetは、上の画像の通り、inputの画像(a)をFeature Map(b)で特徴抽出して、Pyramid Pooling Moudle(c)で4つの異なるサイズのFeature Mapを元のFeature MAPの大きさにUpsamplingした上でconcatしてから予測(d)する。
この(d)の箇所
class DecodePSPFeature(nn.Module): def __init__(self, height, width, n_classes): super(DecodePSPFeature, self).__init__() self.height = height self.width = width self.cbr = Conv2DBatchNormRelu(in_channels=4096, out_channels=512, kernel_size=3, stride=1, padding=1, dilation=1, bias=False) self.dropout = nn.Dropout2d(p=0.1) self.classification = nn.Conv2d(in_channels=512, out_channels=n_classes, kernel_size=1, stride=1, padding=0) def forward(self, x): x = self.cbr(x) x = self.dropout(x) x = self.classification(x) output = F.interpolate(x, size=(self.height, self.width), mode='bilinear', align_corners=True) return output class AuxiliaryPSPlayers(nn.Module): def __init__(self, in_channels, height, width, n_classes): super(AuxiliaryPSPlayers, self).__init__() self.height = height self.width = width self.cbr = Conv2DBatchNormRelu(in_channels=in_channels, out_channels=256, kernel_size=3, stride=1, padding=1, dilation=1, bias=False) self.dropout = nn.Dropout2d(p=0.1) # pointwise convolutionでout_class分の出力に変換。 self.classification = nn.Conv2d(in_channels=256, out_channels=n_classes, kernel_size=1, stride=1, padding=0) def forward(self, x): x = self.cbr(x) x = self.dropout(x) x = self.classification(x) output = F.interpolate(x, size=(self.height, self.width), mode='bilinear', align_corners=True) return output
この二つは似ているのだけど、Decodeの方は最後まで通した上でUpsampling。
AuxはresidualBlock1を通った時点でUpsampling。
Decodeは最後まで通っているのでchannel数が4096がinでAuxはresidualBlock1の出力時点の256がinに指定されている。
interpolateでUpsamplingする前に、両方nn.Conv2d(in_channels=, out_channels=n_classes, kernel_size=1, stride=1, padding=0)を通す。
kernel_size=1のConvolutionは、縦横を全捜査してout_channelsがクラス数分のpoint wise convolutionになっている。
これと正解を比べてロスを作っていく。
今日はここまで。次はロス関数と学習部分。