書いてる理由
- pytorchを基礎からもう一回
参考
詳細
以下のコードでDatasetまで準備
# coding=utf-8 import os import glob import torch.utils.data as data from PIL import Image from vgg_finetune import ImageTransform # データセットはtorch.utils.data.Datasetを継承したクラスを作成する。 class HymenopteraDataset(data.Dataset): # __init__で画像のリストとデータオーギュメンテーションなどの定義が含まれるtransformerなどを受け取る。 def __init__(self, file_list, transform=None, phase='train'): self.file_list = file_list self.transform = transform self.phase = phase def __len__(self): return len(self.file_list) # __getitem__でバッチ中に返却する値を定義する。 def __getitem__(self, index): img_path = self.file_list[index] img = Image.open(img_path) img_transformed = self.transform(img, self.phase) if self.phase == 'train': label = img_path.split('/')[-2] else: label = img_path.split('/')[-2] if label == 'ants': label = 0 elif label == 'bees': label = 1 return img_transformed, label def make_datapath_list(phase='train'): rootpath = os.path.join('../', '..', '1_image_classification', 'data', 'hymenoptera_data') target_path = os.path.join(rootpath, phase, '**', '*.jpg') image_list = [] for path in glob.glob(target_path): image_list.append(path) return image_list def main(): train_list = make_datapath_list('train') val_list = make_datapath_list('val') size = 224 mean = (0.485, 0.456, 0.406) std = (0.229, 0.224, 0.225) train_dataset = HymenopteraDataset(file_list=train_list, transform=ImageTransform(size, mean, std), phase='train') val_dataset = HymenopteraDataset(file_list=val_list, transform=ImageTransform(size, mean, std), phase='val') index = 0 print(train_dataset.__getitem__(index)[0].size()) print(train_dataset.__getitem__(index)[1]) if __name__ == '__main__': main()
実行すると、以下の出力が返ってきて、画像がchannel数, 縦, 横のサイズでラベルが0[ants]をさすデータが取得できていることがわかる。
torch.Size([3, 224, 224])
0
次は、これを使ってデータローダーを作り、バッチごとにローダーでデータを取得して学習するところ。