pytorchのDatasetの準備

書いてる理由

  • pytorchを基礎からもう一回

参考

pytorchによる発展ディープラーニング

詳細

以下のコードでDatasetまで準備

# coding=utf-8
import os
import glob
import torch.utils.data as data
from PIL import Image
from vgg_finetune import ImageTransform


# データセットはtorch.utils.data.Datasetを継承したクラスを作成する。
class HymenopteraDataset(data.Dataset):
    # __init__で画像のリストとデータオーギュメンテーションなどの定義が含まれるtransformerなどを受け取る。
    def __init__(self, file_list, transform=None, phase='train'):
        self.file_list = file_list
        self.transform = transform
        self.phase = phase

    def __len__(self):
        return len(self.file_list)

    # __getitem__でバッチ中に返却する値を定義する。
    def __getitem__(self, index):
        img_path = self.file_list[index]
        img = Image.open(img_path)
        img_transformed = self.transform(img, self.phase)

        if self.phase == 'train':
            label = img_path.split('/')[-2]
        else:
            label = img_path.split('/')[-2]

        if label == 'ants':
            label = 0
        elif label == 'bees':
            label = 1

        return img_transformed, label


def make_datapath_list(phase='train'):
    rootpath = os.path.join('../', '..', '1_image_classification', 'data', 'hymenoptera_data')
    target_path = os.path.join(rootpath, phase, '**', '*.jpg')
    image_list = []
    for path in glob.glob(target_path):
        image_list.append(path)

    return image_list


def main():
    train_list = make_datapath_list('train')
    val_list = make_datapath_list('val')
    size = 224
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    train_dataset = HymenopteraDataset(file_list=train_list, transform=ImageTransform(size, mean, std), phase='train')
    val_dataset = HymenopteraDataset(file_list=val_list, transform=ImageTransform(size, mean, std), phase='val')

    index = 0
    print(train_dataset.__getitem__(index)[0].size())
    print(train_dataset.__getitem__(index)[1])

if __name__ == '__main__':
    main()

実行すると、以下の出力が返ってきて、画像がchannel数, 縦, 横のサイズでラベルが0[ants]をさすデータが取得できていることがわかる。

torch.Size([3, 224, 224])
0

次は、これを使ってデータローダーを作り、バッチごとにローダーでデータを取得して学習するところ。