PytorchでBERTのネットワークを作る(その4 学習)

書いてる理由

参考

pytorchによる発展ディープラーニング

概要

ここまでで用意したネットワーク、プレトレインモデルで学習してポジネガ判定をする。

コード

github.com

詳細

今回は学習部分。ここまで作ってきたネットワークは、特定のタスクに依存しない形で作っており、これを使ってclassification向けにカスタマイズしてポジネガ判定をする。
利用するデータはtransformerの時と同じくIMDb。このデータの扱いはここでは記述しない。以下の記事あたりで

raishi12.hatenablog.com

簡単な流れは、以下。
1. 学習設定とbert用のtokenizerの用意
2. データセットの用意とデータローダーの作成
3. ボキャブラリーを使う用意
4. 分類向けのネットワークの用意とプレトレインモデルのロード
5. 学習させる層の指定
6. 最適化関数と損失関数の作成
7. 学習
8. 予測

1. 学習設定とbert用のtokenizerの用意

max_lengthは1つ文章の利用単語数。batch_sizeはバッチサイズ、num_epochはエポック数、data_pathIMDbやプレトレインモデル、ネットワークのconfファイルが格納されているディレクトリパス。
torchtextで文章とラベルそれぞれの定義をしていく。TEXTとLABELの設定している内容はそれぞれコメントを参照。
tokenizerはbert用のword-pieceで用意する。

    # config setting
    max_length = 256
    batch_size = 32
    num_epochs = 3
    data_path = os.path.join('/path/to/data_directory')

    # Pytorchでテキストを扱うためのデータセットの設定
    TEXT = torchtext.data.Field(sequential=True,                        # データの長さが可変である時True
                                tokenize=tokenizer_with_preprocessing,  # 文章を読み込んだ時に処理する前処理関数の指定
                                use_vocab=True,                         # 単語をボキャブラリーに追加するか
                                lower=True,                             # アルファベットを小文字に変換するか
                                include_lengths=True,                   # 文章の単語数のデータを保持するか
                                batch_first=True,                       # バッチサイズをテンソルの最初の次元で扱う
                                fix_length=max_length,                  # 指定した数になるようにPADDINGする
                                init_token='[CLS]',                     # 文頭
                                eos_token='[SEP]',                      # 文末
                                pad_token='[PAD]',                      # padding
                                unk_token='[UNK]')                      # 未知語
    LABEL = torchtext.data.Field(sequential=False, use_vocab=False)

    bert_vocab_file = os.path.join(data_path, 'vocab', 'bert-base-uncased-vocab.txt')
    bert_tokenizer = BertTokenizer(vocab_file=bert_vocab_file, do_lower_case=True)

2. データセットの用意とデータローダーの作成

データセットとデータローダーもtransformerの時と全く一緒。
torchtext.data.TabularDataset.splitsでpathにディレクトリ名、trainは学習用のtsvファイル名、testはテスト用のtsvファイル名、format/fieldはそれぞれtsvとTEXT,LABELを指定。
今回のデータはIMDb_train.tsvにtrainとvalを混ぜているので、splitで学習8割、検証2割にランダムで割って、それぞれのデータセットを作成。
作成したデータセットtorchtext.data.Iteratorイテレータブルなオブジェクトにしてデータローダーとして定義。

    train_val_ds, test_ds = torchtext.data.TabularDataset.splits(
        path=os.path.join(data_path, 'data'), train='IMDb_train.tsv', test='IMDb_test.tsv', format='tsv',
        fields=[('Text', TEXT), ('Label', LABEL)])

    train_ds, val_ds = train_val_ds.split(split_ratio=0.8, random_state=random.seed(1234))

    # データローダーの用意
    train_dl = torchtext.data.Iterator(train_ds, batch_size=batch_size, train=True)
    val_dl = torchtext.data.Iterator(val_ds, batch_size=batch_size, train=False, sort=False)
    test_dl = torchtext.data.Iterator(test_ds, batch_size=batch_size, train=False, sort=False)

    dataloaders_dict = {'train': train_dl, 'val': val_dl}

3. ボキャブラリーを使う用意

ここが少しトリッキーで、文章を扱うTEXTオブジェクトでTEXT.vocabを使いたいが、build_vocabを一度しておかないと.vocabが利用できない。
今回、ボキャブラリーもプレトレインモデルのものを利用するので、IMDbのデータでのボキャブラリーは不要なのだけど、上の理由からtrain_dsで一度build_vocabしてプレトレインモデルのvocabで上書きすることをしている。

    # TEXTでvocab関数が利用できるようにbuildをするが、今回利用したいのはプレトレインのvocabなのでbuildするが、プレトレインで再度上書きする。ボキャブラリーを使う用意
    vocab_vert, ids_to_tokens_bert = load_vocab(vocab_file=os.path.join(data_path, 'vocab', 'bert-base-uncased-vocab.txt'))

    TEXT.build_vocab(train_ds, min_freq=1)
    TEXT.vocab.stoi = vocab_vert

4. 分類向けのネットワークの用意とプレトレインモデルのロード

bertのネットワーク用の設定ファイルを読んで、前回までのやり方でネットワークを用意する。
さらに用意したネットワークにプレトレインモデルをloadする。
上で書いたように、今回は分類向けの出力としたいので、bertのネットワークの最後に2つの出力をする層をくっつけて、その出力を利用するように修正する。
class BertForIMBdでこれを定義していて、前回作ったモデルのあとに、nn.Linear(in_features=768, out_features=2)をくっつけて、出力を2種類出すようにしている。
これまで通り、[CLS]という文章の最初のベクトルだけを分類に利用するベクトルとして出力するところに注意。

    # BERTのコンフィグファイルからモデルを作成
    conf_file = os.path.join(data_path, 'weights', 'bert_config.json')
    json_file = open(conf_file, 'r')
    json_object = json.load(json_file)

    config = AttrDict(json_object)
    bert_net = BertModel(config)

    # プレトレインのロード
    bert_net = set_learned_params(bert_net, weights_path=os.path.join(data_path, 'weights', 'pytorch_model.bin'))

    # IMDb向けに最後に全結合層を追加したネットワークを作成
    net = BertForIMBd(bert_net)
    net.train()
    print('ネットワーク設定完了')

    # IMBd向けのファインチューニングで最終層以外を学習しないようにする
    for name, param in net.named_parameters():
        param.requires_grad = False
    for name, param in net.bert.encoder.layer[-1].named_parameters():
        param.requires_grad = True
    for name, param in net.cls.named_parameters():
        param.requires_grad = True


class BertForIMBd(nn.Module):
    def __init__(self, net):
        super(BertForIMBd, self).__init__()

        self.bert = net

        # 最終層でポジかネガかを判別するための全結合層
        self.cls = nn.Linear(in_features=768, out_features=2)

        # 重みの初期化
        nn.init.normal_(self.cls.weight, std=0.02)
        nn.init.normal_(self.cls.bias, 0)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=False, attention_show_flg=False):
        if attention_show_flg:
            encoded_layers, pooled_output, attention_probs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers, attention_show_flg)
        else:
            encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers, attention_show_flg)

        vec_0 = encoded_layers[:, 0, :]  # [CLS]の文字のベクトルだけを取り出す
        vec_0 = vec_0.view(-1, 768)  # sizeをバッチサイズ, hidden_sizeに変換
        out = self.cls(vec_0)

        if attention_show_flg:
            return out, attention_probs
        else:
            return out

5. 学習させる層の指定

BERTはめっちゃ大きいネットワークなので、全部学習すると時間がかかる。なので、今回は最終層だけ学習するように設定した。
最初に全ての層でparam.requires_grad = Falseで学習時の勾配の更新を無しにしておいて、最終層だけTrueに変える。
ForIMDbの中でself.clsというのを作っていて、これが追加した層を学習する意味合い。
これは何気に重要なチップスで、class Foo(nn.Module):で作ったコンストラクタの名前で層の名前とかパラメータにアクセスできる。こういうの結構忘れがち。

    # IMBd向けのファインチューニングで最終層以外を学習しないようにする
    for name, param in net.named_parameters():
        param.requires_grad = False
    for name, param in net.bert.encoder.layer[-1].named_parameters():
        param.requires_grad = True
    for name, param in net.cls.named_parameters():
        param.requires_grad = True

6. 最適化関数と損失関数の作成

optimizerはAdamでloss functionは分類なのでcrossentropyで。

    # 最適化関数の定義
    optimizer = optim.Adam([
        {'params': net.bert.encoder.layer[-1].parameters(), 'lr': 5e-5},
        {'params': net.cls.parameters(), 'lr': 5e-5}
    ], betas=(0.9, 0.999))

    # 損失関数の定義
    criterion = nn.CrossEntropyLoss()

7. 学習

学習はほぼtransformerと一緒。載せると長いのでやめる。train_model関数を参照してくだしぃ

8. 予測

最後にテストデータで予測。テストデータもデータローダーを作ったので、それを使ってバッチを回す。
まぁここも特に説明はないかな。

    # テストデータでの予測
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    trained_model.eval()
    trained_model.to(device)

    epoch_corrects = 0

    for batch in test_dl:
        inputs = batch.Text[0].to(device)
        labels = batch.Label.to(device)

        with torch.set_grad_enabled(False):
            outputs = trained_model(inputs, token_type_ids=None, attention_mask=None, output_all_encoded_layers=False, attention_show_flg=False)

            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            epoch_corrects += torch.sum(preds == labels.data)

    epoch_acc = epoch_corrects.double() / len(test_dl.dataset)

    print('テストデータ:{}個での正解率:{:.4}'.format(len(test_dl.dataset), epoch_acc))

はい、以上でBERT終わり!IMDbのデータを読み込むところが20分くらいかかって、その後でエラーで落ちるとまた20分待って、、、ってのを繰り返して思った以上に時間がかかった。
次どーーしよ?レコメンドかきてぇ〜〜〜〜
日本語にBERTを使うのやるかな。

コード全体

def main():
    # config setting
    max_length = 256
    batch_size = 32
    num_epochs = 3
    data_path = os.path.join('/path/to/data_directory')

    # Pytorchでテキストを扱うためのデータセットの設定
    TEXT = torchtext.data.Field(sequential=True,                        # データの長さが可変である時True
                                tokenize=tokenizer_with_preprocessing,  # 文章を読み込んだ時に処理する前処理関数の指定
                                use_vocab=True,                         # 単語をボキャブラリーに追加するか
                                lower=True,                             # アルファベットを小文字に変換するか
                                include_lengths=True,                   # 文章の単語数のデータを保持するか
                                batch_first=True,                       # バッチサイズをテンソルの最初の次元で扱う
                                fix_length=max_length,                  # 指定した数になるようにPADDINGする
                                init_token='[CLS]',                     # 文頭
                                eos_token='[SEP]',                      # 文末
                                pad_token='[PAD]',                      # padding
                                unk_token='[UNK]')                      # 未知語
    LABEL = torchtext.data.Field(sequential=False, use_vocab=False)

    bert_vocab_file = os.path.join(data_path, 'vocab', 'bert-base-uncased-vocab.txt')
    bert_tokenizer = BertTokenizer(vocab_file=bert_vocab_file, do_lower_case=True)

    train_val_ds, test_ds = torchtext.data.TabularDataset.splits(
        path=os.path.join(data_path, 'data'), train='IMDb_train.tsv', test='IMDb_test.tsv', format='tsv',
        fields=[('Text', TEXT), ('Label', LABEL)])

    train_ds, val_ds = train_val_ds.split(split_ratio=0.8, random_state=random.seed(1234))

    # データローダーの用意
    train_dl = torchtext.data.Iterator(train_ds, batch_size=batch_size, train=True)
    val_dl = torchtext.data.Iterator(val_ds, batch_size=batch_size, train=False, sort=False)
    test_dl = torchtext.data.Iterator(test_ds, batch_size=batch_size, train=False, sort=False)

    dataloaders_dict = {'train': train_dl, 'val': val_dl}

    # TEXTでvocab関数が利用できるようにbuildをするが、今回利用したいのはプレトレインのvocabなのでbuildするが、プレトレインで再度上書きする。ボキャブラリーを使う用意
    vocab_vert, ids_to_tokens_bert = load_vocab(vocab_file=os.path.join(data_path, 'vocab', 'bert-base-uncased-vocab.txt'))

    TEXT.build_vocab(train_ds, min_freq=1)
    TEXT.vocab.stoi = vocab_vert

    # 動作確認
    batch = next(iter(val_dl))
    print(batch.Text)
    print(batch.Label)
    text_minibatch_1 = (batch.Text[0][1]).numpy()
    text = bert_tokenizer.convert_ids_to_tokens(text_minibatch_1)
    print(text)

    # BERTのコンフィグファイルからモデルを作成
    conf_file = os.path.join(data_path, 'weights', 'bert_config.json')
    json_file = open(conf_file, 'r')
    json_object = json.load(json_file)

    config = AttrDict(json_object)
    bert_net = BertModel(config)

    # プレトレインのロード
    bert_net = set_learned_params(bert_net, weights_path=os.path.join(data_path, 'weights', 'pytorch_model.bin'))

    # IMDb向けに最後に全結合層を追加したネットワークを作成
    net = BertForIMBd(bert_net)
    net.train()
    print('ネットワーク設定完了')

    # IMBd向けのファインチューニングで最終層以外を学習しないようにする
    for name, param in net.named_parameters():
        param.requires_grad = False
    for name, param in net.bert.encoder.layer[-1].named_parameters():
        param.requires_grad = True
    for name, param in net.cls.named_parameters():
        param.requires_grad = True

    # 最適化関数の定義
    optimizer = optim.Adam([
        {'params': net.bert.encoder.layer[-1].parameters(), 'lr': 5e-5},
        {'params': net.cls.parameters(), 'lr': 5e-5}
    ], betas=(0.9, 0.999))

    # 損失関数の定義
    criterion = nn.CrossEntropyLoss()

    # 学習
    trained_model = train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)

    # 学習済みモデルの保存
    torch.save(trained_model.state_dict(), 'models.pth')

    # テストデータでの予測
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    trained_model.eval()
    trained_model.to(device)

    epoch_corrects = 0

    for batch in test_dl:
        inputs = batch.Text[0].to(device)
        labels = batch.Label.to(device)

        with torch.set_grad_enabled(False):
            outputs = trained_model(inputs, token_type_ids=None, attention_mask=None, output_all_encoded_layers=False, attention_show_flg=False)

            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            epoch_corrects += torch.sum(preds == labels.data)

    epoch_acc = epoch_corrects.double() / len(test_dl.dataset)

    print('テストデータ:{}個での正解率:{:.4}'.format(len(test_dl.dataset), epoch_acc))