pytorchでPSPNet(その3 学習用スクリプト作成1)

書いてる理由

  • pytorchを基礎からもう一回

参考

pytorchによる発展ディープラーニング

https://arxiv.org/pdf/1612.01105.pdf

詳細

pytorch_work/train.py at master · ys201810/pytorch_work · GitHub

これまで作成したネットワーク/Loss関数を使って学習するスクリプトを作る。
これも二つに分けて書こうかな〜
全体は一番下に。

パーツパーツで説明。

configの設定とデータローダーの作成

batch_size: ミニバッチ中に取得する画像の数。GPUのメモリと相談。
n_classes: VOC2012のセグメンテーションが背景込みで21種類なので21。
epochs: 学習データ全体を何回使って学習するかの数。
color_mean/color_std: RGBそれぞれの平均と標準偏差。データを正規化するために利用。
data_root_path: VOC2012のデータをダウンロードしたパスを指定。
[train|val]dataset: VOC2012のデータを画像とアノテーションを取得できるようにしたDatasetクラスで作成。
[train|val]
dataloader: datasetを指定して、batch_sizeとshuffleを指定。shuffleはtrainは偏らないようにTrueでvalはパラメータの更新に寄与しないのでFalse。
dataloaders_dict: 参考にしている著者の流儀でdictにしてまとめる。

    batch_size = 4
    n_classes = 21
    epochs = 30
    color_mean = (0.485, 0.456, 0.406)
    color_std = (0.229, 0.224, 0.225)

    data_root_path = os.path.join('/path', 'to', 'VOCdevkit', 'VOC2012')
    train_img_list, train_anno_list, val_img_list, val_anno_list = make_datapath_list(data_root_path)
    train_dataset = VOCDataset(train_img_list, train_anno_list, phase='train',
                               transform=DataTransform(input_size=475, color_mean=color_mean, color_std=color_std))
    val_dataset = VOCDataset(val_img_list, val_anno_list, phase='val',
                               transform=DataTransform(input_size=475, color_mean=color_mean, color_std=color_std))

    train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    dataloaders_dict = {'train': train_dataloader, 'val': val_dataloader}

ネットワークの準備とプレトレインのload、最後の予測する層の付け替え。

net = PSPNet(n_class=150)は利用するプレトレインが150クラスのセグメンテーション用のものらしく、これに沿ってn_class=150で定義。
net.load()でプレトレインモデルの読み込み。
net.decode_featureとかは、最後の出力層がプレトレインで用意したものが150分類なのに対し、今回は21classなので、out_channels=21になる様に付け替えてる。
付け替えた層を初期化するために、weights_init関数を定義して、Convolutionはxavierの初期化を使い、biasはconstで初期化。

    net = PSPNet(n_classes=150)
    net.load_state_dict(torch.load('./weights/pspnet50_ADE20K.pth'))

    net.decode_feature.classification = nn.Conv2d(in_channels=512, out_channels=n_classes, kernel_size=1, stride=1,
                                                  padding=0)
    net.aux.classification = nn.Conv2d(in_channels=256, out_channels=n_classes, kernel_size=1, stride=1, padding=0)

    def weights_init(m):
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight.data)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    net.decode_feature.classification.apply(weights_init)
    net.aux.classification.apply(weights_init)

最適化関数の定義

optimizerは、全体に対して定義することが多いけど、以下の様に層ごとにlearning rateを指定することもできる。
今回の場合、最後の出力箇所だけ最初から学習なので、最後の方だけ大きめに動かしてそれ以外は小さく動かす様な設定で。
weight decayのスケジューラーは、{1 - (epoch / max_epoch)}^0.9で設定していて、徐々に移動範囲を小さくする様な設定。

    optimizer = optim.SGD([
        {'params': net.feature_conv.parameters(), 'lr': 1e-3},
        {'params': net.feature_res_1.parameters(), 'lr': 1e-3},
        {'params': net.feature_res_2.parameters(), 'lr': 1e-3},
        {'params': net.feature_dilated_res_1.parameters(), 'lr': 1e-3},
        {'params': net.feature_dilated_res_2.parameters(), 'lr': 1e-3},
        {'params': net.pyramid_pooling.parameters(), 'lr': 1e-3},
        {'params': net.decode_feature.parameters(), 'lr': 1e-2},
        {'params': net.aux.parameters(), 'lr': 1e-2},
    ], momentum=0.9, weight_decay=0.0001)

    # スケジューラーの設定
    def lambda_epoch(epoch):
        max_epoch = 30
        return math.pow((1 - epoch / max_epoch), 0.9)

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_epoch)

Loss関数の設定と実際の学習

criterionは前回作成したPSPLossを指定。
これまでに作った、network/dataloaders_dict/criterion/scheduler/optimizer/epochsで学習を開始する。
train_modelが長いので、次に回す。

    criterion = PSPLoss(aux_weight=0.4)

    train_model(net, dataloaders_dict, criterion, scheduler, optimizer, num_epochs=epochs)
def main():
    batch_size = 4
    n_classes = 21
    color_mean = (0.485, 0.456, 0.406)
    color_std = (0.229, 0.224, 0.225)

    data_root_path = os.path.join('/path', 'to', 'VOCdevkit', 'VOC2012')
    train_img_list, train_anno_list, val_img_list, val_anno_list = make_datapath_list(data_root_path)
    train_dataset = VOCDataset(train_img_list, train_anno_list, phase='train',
                               transform=DataTransform(input_size=475, color_mean=color_mean, color_std=color_std))
    val_dataset = VOCDataset(val_img_list, val_anno_list, phase='val',
                               transform=DataTransform(input_size=475, color_mean=color_mean, color_std=color_std))

    train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    dataloaders_dict = {'train': train_dataloader, 'val': val_dataloader}

    net = PSPNet(n_classes=150)
    net.load_state_dict(torch.load('./weights/pspnet50_ADE20K.pth'))

    net.decode_feature.classification = nn.Conv2d(in_channels=512, out_channels=n_classes, kernel_size=1, stride=1,
                                                  padding=0)
    net.aux.classification = nn.Conv2d(in_channels=256, out_channels=n_classes, kernel_size=1, stride=1, padding=0)

    def weights_init(m):
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight.data)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    net.decode_feature.classification.apply(weights_init)
    net.aux.classification.apply(weights_init)

    print('ネットワークの設定完了。学習済みの重みのロード終了')

    # optimizerをネットワークの層の名前ごとにlrを変えて定義
    optimizer = optim.SGD([
        {'params': net.feature_conv.parameters(), 'lr': 1e-3},
        {'params': net.feature_res_1.parameters(), 'lr': 1e-3},
        {'params': net.feature_res_2.parameters(), 'lr': 1e-3},
        {'params': net.feature_dilated_res_1.parameters(), 'lr': 1e-3},
        {'params': net.feature_dilated_res_2.parameters(), 'lr': 1e-3},
        {'params': net.pyramid_pooling.parameters(), 'lr': 1e-3},
        {'params': net.decode_feature.parameters(), 'lr': 1e-2},
        {'params': net.aux.parameters(), 'lr': 1e-2},
    ], momentum=0.9, weight_decay=0.0001)

    # スケジューラーの設定
    def lambda_epoch(epoch):
        max_epoch = 30
        return math.pow((1 - epoch / max_epoch), 0.9)

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_epoch)

    criterion = PSPLoss(aux_weight=0.4)

    train_model(net, dataloaders_dict, criterion, scheduler, optimizer, num_epochs=epochs)