pytorchでPSPNet(その0 Pascal VOCデータのDataloader周り)

書いてる理由

  • pytorchを基礎からもう一回

参考

pytorchによる発展ディープラーニング

詳細

PSPNetでセマンティックセグメンテーションする。
Pascal VOCのデータのダウンロードに残り二日ってどういうこと〜〜???たかが2G1時間でダウンロードしてくれえええええ

ダウンロード完了してないから、Dataloaderまでをとりあえず書いといた。
pytorch_work/3 at master · ys201810/pytorch_work · GitHub

ポイントだけ

# VOCデータを使ってバッチ処理時に画像とアノテーションを取得するためのclass
class VOCDataset(data.Dataset):
    def __init__(self, img_list, anno_list, phase, transform):
        self.img_list = img_list
        self.anno_list = anno_list
        self.phase = phase
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        img, anno_class_img = self.pull_item(index)
        return img, anno_class_img

    def pull_item(self, index):
        image_file_path = self.img_list[index]
        img = Image.open(image_file_path)

        anno_file_path = self.anno_list[index]
        anno_class_img = Image.open(anno_file_path)

        img, anno_class_img = self.transform(self.phase, img, anno_class_img)

        return img, anno_class_img

データローダーを作るためのclassは、torch.utils.data.DatasetをExtendしてclassを作り、
def initでリストを格納。
def getitemでindex番号を引数で受けて、initで作成したリスト[index]でデータを取得する感じで作る。

あとは、train.pyでdataloaderを作ってるが、trainはshuffleをTrueにvalはshuffleをFalseにするのは通例的な感じで、学習時はランダムにすることでリストの上の方に傾向が固まるとかの事象を防ぐためで、valはパラメータの更新に寄与しないからshuffleする必要なしって感じ。

    train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

明日はダウンロードできてっかなぁ。。