書いてる理由
- pytorchを基礎からもう一回
参考
https://arxiv.org/pdf/1612.01105.pdf
詳細
pytorch_work/loss.py at master · ys201810/pytorch_work · GitHub
今日はLoss関数の部分
PSPNetは前回の通り、outputが二つある。一つは最後まで通してUpsamplingしたもの、もう一つは特徴抽出の途中のものをUpsamplingしたもの[Auxiliary]。
この二つをLoss関数として定義するが、重みをつけて合成していて、最後まで通った方が重み1.0、Auxは0.4という形で合成
具体的には以下。
import torch from torch import nn import torch.nn.functional as F class PSPLoss(nn.Module): def __init__(self, aux_weight=0.4): super(PSPLoss, self).__init__() self.aux_weight = aux_weight def forward(self, outputs, targets): loss = F.cross_entropy(outputs[0], targets, reduction='mean') loss_aux = F.cross_entropy(outputs[1], targets, reduction='mean') return loss + self.aux_weight * loss_aux
forwardはnetworkを通って出てきた予測結果でtargetsは正解データ。
これを475*475の全ピクセルでcross entropyで計算してる。
aux_weightは0.4固定で loss + 0.4 * aux_lossをreturnすることで、auxが0.4倍の重みを持ってLossとして計算されている。
今日は以上。
学習部分は思ったより多くなったので分割。