Pytorchでtransformer(その3 学習とテストデータを使った予測)

書いてる理由

  • NLPやるぞー
  • レビューがポジティブかネガティブかを判断する
  • ネットワークは書いたから次は学習と予測

参考

pytorchによる発展ディープラーニング
Attention is All You Need

概要

これまでIMDbのデータを扱う方法と、テキストデータを使った分類のネットワークを作っていったので、今回は実際の学習と予測を実施。

コード

github.com

詳細

データの用意、学習のハイパーパラメータ設定、ネットワークの設定

まずIMDbの学習のためのデータローダーと、学習データで作ったボキャブラリーを扱うための情報をget_IMDb_Dataloader_and_textで取得する。
次にハイパーパラメータとしてmax_seq_lenが1文を扱う単語数の最大値、output_cls_num が予測したいクラス数、learning_rateは学習率、num_epochsは何回学習を回すかをそれぞれセット。
モデルは前回作ったTransformerClassificationで定義して、この中の全結合層のパラメータを初期化する関数を作ってそれを通して設定完了。

def main():
    # config setting
    train_dl, val_dl, test_dl, TEXT = get_IMDb_Dataloader_and_text(256, 24)
    model_dim = TEXT.vocab.vectors.shape[1]  # ベクトルの次元数 shape[0]には単語数が入ってる
    max_seq_len = 256
    output_cls_num = 2
    learning_rate = 2e-5
    num_epochs = 10

    dataloaders_dict = {'train': train_dl, 'val': val_dl}

    model = TransformerClassification(text_embedding_vectors=TEXT.vocab.vectors, model_dim=model_dim,
                                      max_seq_len=max_seq_len, output_dim=output_cls_num)

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Linear') != -1:
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    model.train()

    model.net3_1.apply(weights_init)
    model.net3_2.apply(weights_init)

    print('ネットワークの設定完了')

損失関数と最適化関数の定義

損失関数は、クラス分類問題なのでCrossEntropyで定義。
最適化関数はAdamで今回は定義。

    # 損失関数の定義
    criterion = nn.CrossEntropyLoss()

    # 最適化関数の定義
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

学習

学習は今回は特に難しいところはなく、バッチごとにデータを取ってきて予測して正解ラベルと差分を取ってパラメータ更新を繰り返す。
唯一癖のあるところとしては、paddingをmaskするために、input_mask = (inputs != input_pad)を作っている。これは前回やった通り、ネットワークの中で0に近い値になって出力されて返ってくる。

def train_model(model, dataloaders_dict, criterion, optimizer, num_epochs=10):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    torch.backends.cudnn.benchmark = True

    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            epoch_loss = 0.0
            epoch_corrects = 0

            for batch in dataloaders_dict[phase]:
                inputs = batch.Text[0].to(device)
                labels = batch.Label.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    # maskの作成
                    input_pad = 1
                    input_mask = (inputs != input_pad)

                    outputs, _, _ = model(inputs, input_mask)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)  # ラベルの予測

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    epoch_loss += loss.item() * inputs.size(0)
                    epoch_corrects += torch.sum(preds == labels.data)

            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
            epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)

            print('Epoch {}/{} | {} | Loss: {} Acc: {}'.format(epoch + 1, num_epochs, phase, epoch_loss, epoch_acc))

    torch.save(model.state_dict(), 'models.pth')
    return model


    # 学習
    trained_model = train_model(model, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)

テストデータの予測と可視化

テスト時なので、model.eval()をして、検証モードにする。
このコードをそのまま実行すると、以下のような出力が得られて、テストデータでの精度が確認できる。

テストデータ25000個での正解率:0.84752

    # テストデータで検証
    trained_model.eval()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    trained_model.to(device)

    epoch_corrects = 0
    for i, batch in enumerate(test_dl):
        inputs = batch.Text[0].to(device)
        labels = batch.Label.to(device)

        with torch.set_grad_enabled(False):
            input_pad = 1
            input_mask = (inputs != input_pad)

            outputs, normlized_weights_1, normlized_weights_2 = trained_model(inputs, input_mask)
            _, preds = torch.max(outputs, 1)
            index = 3
            html_output = mk_html(index, batch, preds, normlized_weights_1, normlized_weights_2, TEXT)
            with open(str(i) + '.html', 'w') as outf:
                outf.write(html_output)

            epoch_corrects += torch.sum(preds == labels.data)

    epoch_acc = epoch_corrects.double() / len(test_dl.dataset)

    print('テストデータ{}個での正解率:{}'.format(len(test_dl.dataset), epoch_acc))

mk_htmlは以下のコードとなっている。
Attentionの結果、softmaxで規格化されたデータがnormlized_weights_nの中に入っていて、これが0~1の値なのだが、1に近いほど影響を大きく与えている単語となり、これを視覚化するのがこの関数。
実際の結果が、以下のようになる。

f:id:raishi12:20200329213626p:plain
予測結果の可視化

この結果の場合、正解がPositiveで予測結果もPositiveでうまく予測できている。
赤色が強い単語が結果の判断に大きな影響を与えていて、例えば1段目にfunという単語が赤が強くなっていて、ポジティブなレビューという判定をしているのでは的に確認することができる。

def highlight(word, attn):
    """Attentionの値が大きいと文字の背景が濃い赤になるhtmlを出力させる関数"""

    html_color = '#%02X%02X%02X' % (
        255, int(255*(1 - attn)), int(255*(1 - attn)))
    return '<span style="background-color: {}"> {}</span>'.format(html_color, word)


def mk_html(index, batch, preds, normlized_weights_1, normlized_weights_2, TEXT):
    """HTMLデータを作成する"""

    # indexの結果を抽出
    sentence = batch.Text[0][index]  # 文章
    label = batch.Label[index]  # ラベル
    pred = preds[index]  # 予測

    # indexのAttentionを抽出と規格化
    attens1 = normlized_weights_1[index, 0, :]  # 0番目の<cls>のAttention
    attens1 /= attens1.max()

    attens2 = normlized_weights_2[index, 0, :]  # 0番目の<cls>のAttention
    attens2 /= attens2.max()

    # ラベルと予測結果を文字に置き換え
    if label == 0:
        label_str = "Negative"
    else:
        label_str = "Positive"

    if pred == 0:
        pred_str = "Negative"
    else:
        pred_str = "Positive"

    # 表示用のHTMLを作成する
    html = '正解ラベル:{}<br>推論ラベル:{}<br><br>'.format(label_str, pred_str)

    # 1段目のAttention
    html += '[TransformerBlockの1段目のAttentionを可視化]<br>'
    for word, attn in zip(sentence, attens1):
        html += highlight(TEXT.vocab.itos[word], attn)
    html += "<br><br>"

    # 2段目のAttention
    html += '[TransformerBlockの2段目のAttentionを可視化]<br>'
    for word, attn in zip(sentence, attens2):
        html += highlight(TEXT.vocab.itos[word], attn)

    html += "<br><br>"

    return html

以上!
ネットワークのAttentionはマジでもうちょっと深く理解したい。。
次何しよっかなーー。BERTかSSDかitemのレコメンド。できたらレコメンドやりたいな。

コード全体

def highlight(word, attn):
    """Attentionの値が大きいと文字の背景が濃い赤になるhtmlを出力させる関数"""

    html_color = '#%02X%02X%02X' % (
        255, int(255*(1 - attn)), int(255*(1 - attn)))
    return '<span style="background-color: {}"> {}</span>'.format(html_color, word)


def mk_html(index, batch, preds, normlized_weights_1, normlized_weights_2, TEXT):
    """HTMLデータを作成する"""

    # indexの結果を抽出
    sentence = batch.Text[0][index]  # 文章
    label = batch.Label[index]  # ラベル
    pred = preds[index]  # 予測

    # indexのAttentionを抽出と規格化
    attens1 = normlized_weights_1[index, 0, :]  # 0番目の<cls>のAttention
    attens1 /= attens1.max()

    attens2 = normlized_weights_2[index, 0, :]  # 0番目の<cls>のAttention
    attens2 /= attens2.max()

    # ラベルと予測結果を文字に置き換え
    if label == 0:
        label_str = "Negative"
    else:
        label_str = "Positive"

    if pred == 0:
        pred_str = "Negative"
    else:
        pred_str = "Positive"

    # 表示用のHTMLを作成する
    html = '正解ラベル:{}<br>推論ラベル:{}<br><br>'.format(label_str, pred_str)

    # 1段目のAttention
    html += '[TransformerBlockの1段目のAttentionを可視化]<br>'
    for word, attn in zip(sentence, attens1):
        html += highlight(TEXT.vocab.itos[word], attn)
    html += "<br><br>"

    # 2段目のAttention
    html += '[TransformerBlockの2段目のAttentionを可視化]<br>'
    for word, attn in zip(sentence, attens2):
        html += highlight(TEXT.vocab.itos[word], attn)

    html += "<br><br>"

    return html


def train_model(model, dataloaders_dict, criterion, optimizer, num_epochs=10):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    torch.backends.cudnn.benchmark = True

    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            epoch_loss = 0.0
            epoch_corrects = 0

            for batch in dataloaders_dict[phase]:
                inputs = batch.Text[0].to(device)
                labels = batch.Label.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    # maskの作成
                    input_pad = 1
                    input_mask = (inputs != input_pad)

                    outputs, _, _ = model(inputs, input_mask)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)  # ラベルの予測

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    epoch_loss += loss.item() * inputs.size(0)
                    epoch_corrects += torch.sum(preds == labels.data)

            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
            epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)

            print('Epoch {}/{} | {} | Loss: {} Acc: {}'.format(epoch + 1, num_epochs, phase, epoch_loss, epoch_acc))

    torch.save(model.state_dict(), 'models.pth')
    return model


def main():
    # config setting
    train_dl, val_dl, test_dl, TEXT = get_IMDb_Dataloader_and_text(256, 24)
    model_dim = TEXT.vocab.vectors.shape[1]  # ベクトルの次元数 shape[0]には単語数が入ってる
    max_seq_len = 256
    output_cls_num = 2
    learning_rate = 2e-5
    num_epochs = 10

    dataloaders_dict = {'train': train_dl, 'val': val_dl}

    model = TransformerClassification(text_embedding_vectors=TEXT.vocab.vectors, model_dim=model_dim,
                                      max_seq_len=max_seq_len, output_dim=output_cls_num)

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Linear') != -1:
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    model.train()

    model.net3_1.apply(weights_init)
    model.net3_2.apply(weights_init)

    print('ネットワークの設定完了')

    # 損失関数の定義
    criterion = nn.CrossEntropyLoss()

    # 最適化関数の定義
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # 学習
    trained_model = train_model(model, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)

    # テストデータで検証
    trained_model.eval()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    trained_model.to(device)

    epoch_corrects = 0
    for i, batch in enumerate(test_dl):
        inputs = batch.Text[0].to(device)
        labels = batch.Label.to(device)

        with torch.set_grad_enabled(False):
            input_pad = 1
            input_mask = (inputs != input_pad)

            outputs, normlized_weights_1, normlized_weights_2 = trained_model(inputs, input_mask)
            _, preds = torch.max(outputs, 1)
            index = 3
            html_output = mk_html(index, batch, preds, normlized_weights_1, normlized_weights_2, TEXT)
            with open(str(i) + '.html', 'w') as outf:
                outf.write(html_output)

            epoch_corrects += torch.sum(preds == labels.data)

    epoch_acc = epoch_corrects.double() / len(test_dl.dataset)

    print('テストデータ{}個での正解率:{}'.format(len(test_dl.dataset), epoch_acc))


if __name__ == '__main__':
    main()