pytorchでPSPNet(その2 Loss関数を作る)

書いてる理由

  • pytorchを基礎からもう一回

参考

pytorchによる発展ディープラーニング

https://arxiv.org/pdf/1612.01105.pdf

詳細

pytorch_work/loss.py at master · ys201810/pytorch_work · GitHub

今日はLoss関数の部分
PSPNetは前回の通り、outputが二つある。一つは最後まで通してUpsamplingしたもの、もう一つは特徴抽出の途中のものをUpsamplingしたもの[Auxiliary]。
この二つをLoss関数として定義するが、重みをつけて合成していて、最後まで通った方が重み1.0、Auxは0.4という形で合成

具体的には以下。

import torch
from torch import nn
import torch.nn.functional as F


class PSPLoss(nn.Module):
    def __init__(self, aux_weight=0.4):
        super(PSPLoss, self).__init__()
        self.aux_weight = aux_weight

    def forward(self, outputs, targets):
        loss = F.cross_entropy(outputs[0], targets, reduction='mean')
        loss_aux = F.cross_entropy(outputs[1], targets, reduction='mean')

        return loss + self.aux_weight * loss_aux

forwardはnetworkを通って出てきた予測結果でtargetsは正解データ。
これを475*475の全ピクセルでcross entropyで計算してる。
aux_weightは0.4固定で loss + 0.4 * aux_lossをreturnすることで、auxが0.4倍の重みを持ってLossとして計算されている。

今日は以上。
学習部分は思ったより多くなったので分割。