PytorchのEmbeddingメモ

いつもtorch.nn.Embeddingの意味合いを忘れるのでメモ。

import torch
from torch import nn

embed = nn.Embedding(num_embeddings=4, embedding_dim=10, padding_idx=0)  # num_embeddingsが種類の数、embedding_dimはベクトル表現の次元数、padding_idxがpaddingの単語のindex番号

a(torch.Tensor([0]).type(torch.long))
>tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<EmbeddingBackward>)

a(torch.Tensor([1]).type(torch.long))
>tensor([[ 0.2685,  1.0670, -1.3489, -0.9967,  1.3931, -0.2417,  0.4070, -0.3864,
>         -0.2082,  1.7413]], grad_fn=<EmbeddingBackward>)

a(torch.Tensor([2]).type(torch.long))
> tensor([[-0.6267,  0.0558,  0.0259,  0.6943, -0.2168, -1.1249, -0.9203,  0.6561,
>          1.4061,  0.8107]], grad_fn=<EmbeddingBackward>)

a(torch.Tensor([3]).type(torch.long))
> tensor([[-0.5021, -1.2607,  0.2670, -0.9847, -1.8066,  0.1945, -0.4929, -1.5729,
>          -0.1613, -0.5040]], grad_fn=<EmbeddingBackward>)

a(torch.Tensor([4]).type(torch.long))
> Traceback (most recent call last):
>   File "<stdin>", line 1, in <module>
>   File "/Users/shirai1/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
>     result = self.forward(*input, **kwargs)
>   File "/Users/shirai1/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/torch/nn/modules/sparse.py", line 114, in forward
>     self.norm_type, self.scale_grad_by_freq, self.sparse)
>   File "/Users/shirai1/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/torch/nn/functional.py", line 1484, in embedding
>     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
> RuntimeError: index out of range: Tried to access index 4 out of table with 3 rows. at ../aten/src/TH/generic/THTensorEvenMoreMath.cpp:418

4種類の単語をベクトル化したいという時に利用。ベクトルの数値はランダム。
Embeddingで定義したnum_embeddingsより大きいindex番号を指定すると、定義してないよーって怒られる。

ちなみに、word2vecなどの学習済みのモデルを用いてindexからベクトル表現を取って来る場合は以下。

        self.embeddings = nn.Embedding.from_pretrained(embeddings=text_embedding_vectors, freeze=True)  # freeze: バックプロップでの更新なし

text_embedding_vectorsは、以下のように作られるもの。

import torchtext
from torchtext.vocab import Vectors

def main():
    TEXT = torchtext.data.Field(sequential=True,  # データの長さが可変ならTrue
                                tokenize=tokenizer_with_preprocessing,  # 前処理用の関数
                                use_vocab=True,  # ボキャブラリーを作成するか
                                lower=True,  # アルファベットを小文字に変換
                                include_lengths=True,  # 文章の単語数の情報の保持をするか
                                batch_first=True,  # バッチサイズの次元を最初に
                                fix_length=max_length,  # この文字数以上になったらカット、以下なら<pad>を埋める
                                init_token='<cls>',  # 文章の最初に<cls>という文字列をセット
                                eos_token='<eos>')  # 文章の最後に<eos>という文字列をセット

    TEXT.build_vocab(train_ds, vectors=Vectors(name=os.path.join(data_path, 'wiki-news-300d-1M.vec')), min_freq=10)