書いてる理由
- pytorchを基礎からもう一回
参考
詳細
PSPNetでセマンティックセグメンテーションする。
Pascal VOCのデータのダウンロードに残り二日ってどういうこと〜〜???たかが2G1時間でダウンロードしてくれえええええ
ダウンロード完了してないから、Dataloaderまでをとりあえず書いといた。
pytorch_work/3 at master · ys201810/pytorch_work · GitHub
ポイントだけ
# VOCデータを使ってバッチ処理時に画像とアノテーションを取得するためのclass class VOCDataset(data.Dataset): def __init__(self, img_list, anno_list, phase, transform): self.img_list = img_list self.anno_list = anno_list self.phase = phase self.transform = transform def __len__(self): return len(self.img_list) def __getitem__(self, index): img, anno_class_img = self.pull_item(index) return img, anno_class_img def pull_item(self, index): image_file_path = self.img_list[index] img = Image.open(image_file_path) anno_file_path = self.anno_list[index] anno_class_img = Image.open(anno_file_path) img, anno_class_img = self.transform(self.phase, img, anno_class_img) return img, anno_class_img
データローダーを作るためのclassは、torch.utils.data.DatasetをExtendしてclassを作り、
def initでリストを格納。
def getitemでindex番号を引数で受けて、initで作成したリスト[index]でデータを取得する感じで作る。
あとは、train.pyでdataloaderを作ってるが、trainはshuffleをTrueにvalはshuffleをFalseにするのは通例的な感じで、学習時はランダムにすることでリストの上の方に傾向が固まるとかの事象を防ぐためで、valはパラメータの更新に寄与しないからshuffleする必要なしって感じ。
train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) val_dataloader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
明日はダウンロードできてっかなぁ。。