書いてる理由
- NLPやるぞー
- レビューがポジティブかネガティブかを判断する
- ネットワークは書いたから次は学習と予測
参考
pytorchによる発展ディープラーニング
Attention is All You Need
概要
これまでIMDbのデータを扱う方法と、テキストデータを使った分類のネットワークを作っていったので、今回は実際の学習と予測を実施。
コード
詳細
データの用意、学習のハイパーパラメータ設定、ネットワークの設定
まずIMDbの学習のためのデータローダーと、学習データで作ったボキャブラリーを扱うための情報をget_IMDb_Dataloader_and_text
で取得する。
次にハイパーパラメータとしてmax_seq_len
が1文を扱う単語数の最大値、output_cls_num
が予測したいクラス数、learning_rate
は学習率、num_epochs
は何回学習を回すかをそれぞれセット。
モデルは前回作ったTransformerClassification
で定義して、この中の全結合層のパラメータを初期化する関数を作ってそれを通して設定完了。
def main(): # config setting train_dl, val_dl, test_dl, TEXT = get_IMDb_Dataloader_and_text(256, 24) model_dim = TEXT.vocab.vectors.shape[1] # ベクトルの次元数 shape[0]には単語数が入ってる max_seq_len = 256 output_cls_num = 2 learning_rate = 2e-5 num_epochs = 10 dataloaders_dict = {'train': train_dl, 'val': val_dl} model = TransformerClassification(text_embedding_vectors=TEXT.vocab.vectors, model_dim=model_dim, max_seq_len=max_seq_len, output_dim=output_cls_num) def weights_init(m): classname = m.__class__.__name__ if classname.find('Linear') != -1: nn.init.kaiming_normal_(m.weight) if m.bias is not None: nn.init.constant_(m.bias, 0.0) model.train() model.net3_1.apply(weights_init) model.net3_2.apply(weights_init) print('ネットワークの設定完了')
損失関数と最適化関数の定義
損失関数は、クラス分類問題なのでCrossEntropyで定義。
最適化関数はAdamで今回は定義。
# 損失関数の定義 criterion = nn.CrossEntropyLoss() # 最適化関数の定義 optimizer = optim.Adam(model.parameters(), lr=learning_rate)
学習
学習は今回は特に難しいところはなく、バッチごとにデータを取ってきて予測して正解ラベルと差分を取ってパラメータ更新を繰り返す。
唯一癖のあるところとしては、paddingをmaskするために、input_mask = (inputs != input_pad)
を作っている。これは前回やった通り、ネットワークの中で0に近い値になって出力されて返ってくる。
def train_model(model, dataloaders_dict, criterion, optimizer, num_epochs=10): device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') model.to(device) torch.backends.cudnn.benchmark = True for epoch in range(num_epochs): for phase in ['train', 'val']: if phase == 'train': model.train() else: model.eval() epoch_loss = 0.0 epoch_corrects = 0 for batch in dataloaders_dict[phase]: inputs = batch.Text[0].to(device) labels = batch.Label.to(device) optimizer.zero_grad() with torch.set_grad_enabled(phase == 'train'): # maskの作成 input_pad = 1 input_mask = (inputs != input_pad) outputs, _, _ = model(inputs, input_mask) loss = criterion(outputs, labels) _, preds = torch.max(outputs, 1) # ラベルの予測 if phase == 'train': loss.backward() optimizer.step() epoch_loss += loss.item() * inputs.size(0) epoch_corrects += torch.sum(preds == labels.data) epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset) epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset) print('Epoch {}/{} | {} | Loss: {} Acc: {}'.format(epoch + 1, num_epochs, phase, epoch_loss, epoch_acc)) torch.save(model.state_dict(), 'models.pth') return model # 学習 trained_model = train_model(model, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)
テストデータの予測と可視化
テスト時なので、model.eval()
をして、検証モードにする。
このコードをそのまま実行すると、以下のような出力が得られて、テストデータでの精度が確認できる。
テストデータ25000個での正解率:0.84752
# テストデータで検証 trained_model.eval() device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') trained_model.to(device) epoch_corrects = 0 for i, batch in enumerate(test_dl): inputs = batch.Text[0].to(device) labels = batch.Label.to(device) with torch.set_grad_enabled(False): input_pad = 1 input_mask = (inputs != input_pad) outputs, normlized_weights_1, normlized_weights_2 = trained_model(inputs, input_mask) _, preds = torch.max(outputs, 1) index = 3 html_output = mk_html(index, batch, preds, normlized_weights_1, normlized_weights_2, TEXT) with open(str(i) + '.html', 'w') as outf: outf.write(html_output) epoch_corrects += torch.sum(preds == labels.data) epoch_acc = epoch_corrects.double() / len(test_dl.dataset) print('テストデータ{}個での正解率:{}'.format(len(test_dl.dataset), epoch_acc))
mk_html
は以下のコードとなっている。
Attentionの結果、softmaxで規格化されたデータがnormlized_weights_nの中に入っていて、これが0~1の値なのだが、1に近いほど影響を大きく与えている単語となり、これを視覚化するのがこの関数。
実際の結果が、以下のようになる。
この結果の場合、正解がPositiveで予測結果もPositiveでうまく予測できている。
赤色が強い単語が結果の判断に大きな影響を与えていて、例えば1段目にfunという単語が赤が強くなっていて、ポジティブなレビューという判定をしているのでは的に確認することができる。
def highlight(word, attn): """Attentionの値が大きいと文字の背景が濃い赤になるhtmlを出力させる関数""" html_color = '#%02X%02X%02X' % ( 255, int(255*(1 - attn)), int(255*(1 - attn))) return '<span style="background-color: {}"> {}</span>'.format(html_color, word) def mk_html(index, batch, preds, normlized_weights_1, normlized_weights_2, TEXT): """HTMLデータを作成する""" # indexの結果を抽出 sentence = batch.Text[0][index] # 文章 label = batch.Label[index] # ラベル pred = preds[index] # 予測 # indexのAttentionを抽出と規格化 attens1 = normlized_weights_1[index, 0, :] # 0番目の<cls>のAttention attens1 /= attens1.max() attens2 = normlized_weights_2[index, 0, :] # 0番目の<cls>のAttention attens2 /= attens2.max() # ラベルと予測結果を文字に置き換え if label == 0: label_str = "Negative" else: label_str = "Positive" if pred == 0: pred_str = "Negative" else: pred_str = "Positive" # 表示用のHTMLを作成する html = '正解ラベル:{}<br>推論ラベル:{}<br><br>'.format(label_str, pred_str) # 1段目のAttention html += '[TransformerBlockの1段目のAttentionを可視化]<br>' for word, attn in zip(sentence, attens1): html += highlight(TEXT.vocab.itos[word], attn) html += "<br><br>" # 2段目のAttention html += '[TransformerBlockの2段目のAttentionを可視化]<br>' for word, attn in zip(sentence, attens2): html += highlight(TEXT.vocab.itos[word], attn) html += "<br><br>" return html
以上!
ネットワークのAttentionはマジでもうちょっと深く理解したい。。
次何しよっかなーー。BERTかSSDかitemのレコメンド。できたらレコメンドやりたいな。
コード全体
def highlight(word, attn): """Attentionの値が大きいと文字の背景が濃い赤になるhtmlを出力させる関数""" html_color = '#%02X%02X%02X' % ( 255, int(255*(1 - attn)), int(255*(1 - attn))) return '<span style="background-color: {}"> {}</span>'.format(html_color, word) def mk_html(index, batch, preds, normlized_weights_1, normlized_weights_2, TEXT): """HTMLデータを作成する""" # indexの結果を抽出 sentence = batch.Text[0][index] # 文章 label = batch.Label[index] # ラベル pred = preds[index] # 予測 # indexのAttentionを抽出と規格化 attens1 = normlized_weights_1[index, 0, :] # 0番目の<cls>のAttention attens1 /= attens1.max() attens2 = normlized_weights_2[index, 0, :] # 0番目の<cls>のAttention attens2 /= attens2.max() # ラベルと予測結果を文字に置き換え if label == 0: label_str = "Negative" else: label_str = "Positive" if pred == 0: pred_str = "Negative" else: pred_str = "Positive" # 表示用のHTMLを作成する html = '正解ラベル:{}<br>推論ラベル:{}<br><br>'.format(label_str, pred_str) # 1段目のAttention html += '[TransformerBlockの1段目のAttentionを可視化]<br>' for word, attn in zip(sentence, attens1): html += highlight(TEXT.vocab.itos[word], attn) html += "<br><br>" # 2段目のAttention html += '[TransformerBlockの2段目のAttentionを可視化]<br>' for word, attn in zip(sentence, attens2): html += highlight(TEXT.vocab.itos[word], attn) html += "<br><br>" return html def train_model(model, dataloaders_dict, criterion, optimizer, num_epochs=10): device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') model.to(device) torch.backends.cudnn.benchmark = True for epoch in range(num_epochs): for phase in ['train', 'val']: if phase == 'train': model.train() else: model.eval() epoch_loss = 0.0 epoch_corrects = 0 for batch in dataloaders_dict[phase]: inputs = batch.Text[0].to(device) labels = batch.Label.to(device) optimizer.zero_grad() with torch.set_grad_enabled(phase == 'train'): # maskの作成 input_pad = 1 input_mask = (inputs != input_pad) outputs, _, _ = model(inputs, input_mask) loss = criterion(outputs, labels) _, preds = torch.max(outputs, 1) # ラベルの予測 if phase == 'train': loss.backward() optimizer.step() epoch_loss += loss.item() * inputs.size(0) epoch_corrects += torch.sum(preds == labels.data) epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset) epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset) print('Epoch {}/{} | {} | Loss: {} Acc: {}'.format(epoch + 1, num_epochs, phase, epoch_loss, epoch_acc)) torch.save(model.state_dict(), 'models.pth') return model def main(): # config setting train_dl, val_dl, test_dl, TEXT = get_IMDb_Dataloader_and_text(256, 24) model_dim = TEXT.vocab.vectors.shape[1] # ベクトルの次元数 shape[0]には単語数が入ってる max_seq_len = 256 output_cls_num = 2 learning_rate = 2e-5 num_epochs = 10 dataloaders_dict = {'train': train_dl, 'val': val_dl} model = TransformerClassification(text_embedding_vectors=TEXT.vocab.vectors, model_dim=model_dim, max_seq_len=max_seq_len, output_dim=output_cls_num) def weights_init(m): classname = m.__class__.__name__ if classname.find('Linear') != -1: nn.init.kaiming_normal_(m.weight) if m.bias is not None: nn.init.constant_(m.bias, 0.0) model.train() model.net3_1.apply(weights_init) model.net3_2.apply(weights_init) print('ネットワークの設定完了') # 損失関数の定義 criterion = nn.CrossEntropyLoss() # 最適化関数の定義 optimizer = optim.Adam(model.parameters(), lr=learning_rate) # 学習 trained_model = train_model(model, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs) # テストデータで検証 trained_model.eval() device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') trained_model.to(device) epoch_corrects = 0 for i, batch in enumerate(test_dl): inputs = batch.Text[0].to(device) labels = batch.Label.to(device) with torch.set_grad_enabled(False): input_pad = 1 input_mask = (inputs != input_pad) outputs, normlized_weights_1, normlized_weights_2 = trained_model(inputs, input_mask) _, preds = torch.max(outputs, 1) index = 3 html_output = mk_html(index, batch, preds, normlized_weights_1, normlized_weights_2, TEXT) with open(str(i) + '.html', 'w') as outf: outf.write(html_output) epoch_corrects += torch.sum(preds == labels.data) epoch_acc = epoch_corrects.double() / len(test_dl.dataset) print('テストデータ{}個での正解率:{}'.format(len(test_dl.dataset), epoch_acc)) if __name__ == '__main__': main()