いつもtorch.nn.Embeddingの意味合いを忘れるのでメモ。
import torch from torch import nn embed = nn.Embedding(num_embeddings=4, embedding_dim=10, padding_idx=0) # num_embeddingsが種類の数、embedding_dimはベクトル表現の次元数、padding_idxがpaddingの単語のindex番号 a(torch.Tensor([0]).type(torch.long)) >tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<EmbeddingBackward>) a(torch.Tensor([1]).type(torch.long)) >tensor([[ 0.2685, 1.0670, -1.3489, -0.9967, 1.3931, -0.2417, 0.4070, -0.3864, > -0.2082, 1.7413]], grad_fn=<EmbeddingBackward>) a(torch.Tensor([2]).type(torch.long)) > tensor([[-0.6267, 0.0558, 0.0259, 0.6943, -0.2168, -1.1249, -0.9203, 0.6561, > 1.4061, 0.8107]], grad_fn=<EmbeddingBackward>) a(torch.Tensor([3]).type(torch.long)) > tensor([[-0.5021, -1.2607, 0.2670, -0.9847, -1.8066, 0.1945, -0.4929, -1.5729, > -0.1613, -0.5040]], grad_fn=<EmbeddingBackward>) a(torch.Tensor([4]).type(torch.long)) > Traceback (most recent call last): > File "<stdin>", line 1, in <module> > File "/Users/shirai1/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__ > result = self.forward(*input, **kwargs) > File "/Users/shirai1/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/torch/nn/modules/sparse.py", line 114, in forward > self.norm_type, self.scale_grad_by_freq, self.sparse) > File "/Users/shirai1/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/torch/nn/functional.py", line 1484, in embedding > return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) > RuntimeError: index out of range: Tried to access index 4 out of table with 3 rows. at ../aten/src/TH/generic/THTensorEvenMoreMath.cpp:418
4種類の単語をベクトル化したいという時に利用。ベクトルの数値はランダム。
Embeddingで定義したnum_embeddingsより大きいindex番号を指定すると、定義してないよーって怒られる。
ちなみに、word2vecなどの学習済みのモデルを用いてindexからベクトル表現を取って来る場合は以下。
self.embeddings = nn.Embedding.from_pretrained(embeddings=text_embedding_vectors, freeze=True) # freeze: バックプロップでの更新なし
text_embedding_vectorsは、以下のように作られるもの。
import torchtext from torchtext.vocab import Vectors def main(): TEXT = torchtext.data.Field(sequential=True, # データの長さが可変ならTrue tokenize=tokenizer_with_preprocessing, # 前処理用の関数 use_vocab=True, # ボキャブラリーを作成するか lower=True, # アルファベットを小文字に変換 include_lengths=True, # 文章の単語数の情報の保持をするか batch_first=True, # バッチサイズの次元を最初に fix_length=max_length, # この文字数以上になったらカット、以下なら<pad>を埋める init_token='<cls>', # 文章の最初に<cls>という文字列をセット eos_token='<eos>') # 文章の最後に<eos>という文字列をセット TEXT.build_vocab(train_ds, vectors=Vectors(name=os.path.join(data_path, 'wiki-news-300d-1M.vec')), min_freq=10)