【論文解説】Embeddingとは?BERT論文からLLMが単語や文章をベクトルで表す仕組みを解説

文章がTokenization、token ID、Embedding、Transformer Encoderを通って文脈化Embeddingになる流れ

Embeddingとは?BERT論文からLLMが単語や文章をベクトルで表す仕組みを解説

3文要約

文章がTokenization、token ID、Embedding、Transformer Encoderを通って文脈化Embeddingになる流れ

Embedding(埋め込み)は、token IDを固定次元のベクトルへ変換し、単語やサブワードを計算可能な表現にする層です。

BERT論文では、WordPiece tokenのEmbeddingに、文A/Bを区別するsegment embeddingと、語順を表すposition embeddingを足し合わせ、Transformer Encoderへ入力します。

重要なのは、BERTの出力Embeddingが単なる単語辞書ではなく、左右の文脈を反映したcontextualized representation(文脈化表現)になる点です。

論文情報

項目 内容
論文タイトル BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
著者 Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova
初版公開日 2018年10月11日
改訂版 2019年5月24日 v2
分野 Computation and Language
arXiv BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
DOI 10.48550/arXiv.1810.04805

Embeddingとは何か

token IDがEmbedding行列の対応する行から連続的な数値列のベクトルへ変換される具体例

Embeddingは、離散的なIDを連続値のベクトルへ変換する仕組みです。

日本語では「埋め込み」と訳されます。

これは、単語やtokenを、そのまま文字として扱うのではなく、意味や使われ方を表しやすい連続的なベクトル空間の中へ配置する、というイメージです。

たとえば、token ID 1176 を [0.12, -0.08, 0.44, ...] のような数値列へ変換します。

このように、離散的なIDをニューラルネットワークが計算できる連続的なベクトルへ写すことを、Embedding、つまり埋め込みと呼びます。

前回のTokenization編では、文章がtoken列とtoken ID列へ変換されるところまで見ました。

しかし、Transformer(Attentionを中心に系列を処理するニューラルネットワーク)は、101 や 7592 のような整数IDを、そのまま意味として理解するわけではありません。

モデル内部では、各IDに対応するベクトルをEmbedding行列から取り出します。

語彙サイズを $V$、Embedding次元を $H$ とすると、Embedding行列は次の形になります。

ERV×H

 

token ID $i$ に対するEmbeddingは、行列 $E$ の $i$ 番目の行です。

xi=E[i]

 

この $x_i$ が、Transformerへ渡される最初の表現になります。
表現 役割
文字列 I like transformers 人間が読む入力
token [I, like, transformers] tokenizerが作る単位
token ID [146, 1176, 19081] 語彙表上の整数
Embedding [[...], [...], [...]] モデルが計算するベクトル

Embeddingを使うことで、似た使われ方をするtokenが、学習を通じて近い方向のベクトルになりやすくなります。

ただし、Embeddingそのものに最初から意味が入っているわけではありません。

大量のテキストで学習する過程で、意味や構文に関係するパターンが重みとして獲得されます。

BERT論文での入力表現

BERTの入力表現がtoken embedding、segment embedding、position embeddingの和で作られることを示す図

BERTの入力表現は、単一のtoken embeddingだけではありません。

論文では、各入力位置の表現を次の3種類のEmbeddingの和として作ります。

hi(0)=etoken(ti)+esegment(si)+eposition(i)

ここで、$t_i$ はtoken ID、$s_i$ は文A/Bを示すsegment ID、$i$ は系列内の位置です。

Embedding 何を表すか BERTでの役割
token embedding token IDに対応する意味の初期表現 WordPiece tokenをベクトル化する
segment embedding tokenが文Aか文Bか 文ペアタスクやNSPで入力の所属を区別する
position embedding tokenの位置 Transformerに語順情報を与える

BERTはTransformer Encoderを使います。

Self-Attention(系列中のtoken同士の関係を重み付きで見る仕組み)は、入力の順番をそのまま保持する構造ではありません。

そのため、position embeddingを足して、どのtokenが何番目にあるかを伝えます。

WordPieceと特殊token

WordPieceで文がサブワードと特殊tokenを含むBERT入力列に変換される流れ

BERTは、WordPiece embeddingを使います。

WordPiece(単語をサブワードへ分割する方式)は、未知語を減らしながら語彙サイズを抑えるためのtokenization手法です。

論文では、30,000 tokenの語彙が使われています。

入力には、通常の単語やサブワードだけでなく、特殊tokenも含まれます。

token 役割
[CLS] 系列の先頭に置かれ、分類タスクの代表表現として使われる
[SEP] 文の区切りや文ペアの境界を示す
[MASK] Masked Language Modelingで予測対象を隠す

たとえば、文ペアを入力する場合は、次のような形になります。

[CLS] my dog is cute [SEP] he likes playing [SEP]

このとき、前半の文にはsegment A、後半の文にはsegment BのEmbeddingが足されます。

tokens:   [CLS] my dog is cute [SEP] he likes playing [SEP]
segment:    A   A  A   A   A    A    B    B      B     B
position:   0   1  2   3   4    5    6    7      8     9

この3つの情報を足し合わせることで、BERTは「どのtokenか」「どちらの文か」「何番目か」を同時に受け取ります。

静的Embeddingと文脈化Embeddingの違い

静的EmbeddingとBERTの文脈化Embeddingの違いをbankの例で比較した図

Embeddingを理解するときに重要なのが、静的Embeddingと文脈化Embeddingの違いです。

静的Embeddingでは、同じ単語は基本的に同じベクトルになります。

一方、BERTのようなモデルでは、入力EmbeddingがTransformer層を通過したあと、周囲の文脈に応じた表現へ変わります。

観点 静的Embedding BERTの文脈化Embedding
同じ単語の表現 原則として同じ 文脈によって変わる
word2vec、GloVe BERTの各層出力
多義語 1つのベクトルに混ざりやすい 文脈から意味を分けやすい
代表的な使い方 類似語検索、特徴量 分類、抽出、検索、再ランキング

たとえば、bank という英単語は「銀行」と「川岸」の意味を持ちます。

静的Embeddingでは、この2つの意味が1つのベクトルに混ざりやすくなります。

しかし、BERTでは次の2文の bank が異なる文脈で処理されます。

I deposited money at the bank.
I sat on the bank of the river.

入力時点のtoken embeddingは同じでも、Self-Attentionによって周囲の語と相互作用した後の出力表現は変わります。

これが、BERTのEmbeddingを「文脈化表現」として扱える理由です。

BERTはなぜ双方向なのか

BERTが左右の文脈を使ってMASKされたtokenを予測する双方向表現のイメージ

BERTの大きな特徴は、deep bidirectional representations(深い双方向表現)を事前学習する点です。

GPT系の自己回帰モデルは、基本的に左から右へ次tokenを予測します。

一方、BERTはTransformer Encoderを使い、各層で左右両方の文脈を参照します。

観点 左から右の言語モデル BERT
参照できる文脈 予測位置より左側 左右両方
主な目的 次token予測 Masked Language Modeling
得意な用途 生成 理解、分類、抽出
出力の使い方 次のtoken分布 token/文の表現

ただし、双方向に全文を見せたまま「次の単語を当てる」学習をすると、答えのtoken自身を見てしまう問題があります。

そこでBERTは、入力の一部を隠して、その元tokenを当てるMasked Language Modelingを使います。

Masked Language Modeling

BERTのMasked Language Modelingで入力tokenの一部を隠し、左右の文脈から元tokenを予測する流れ

Masked Language Modeling(入力の一部を隠し、元のtokenを予測する事前学習タスク)は、BERTの中心的な学習目的です。

論文では、入力tokenの15%を予測対象にします。

ただし、選ばれたtokenをすべて [MASK] に置き換えるわけではありません。

選ばれたtokenの処理 割合 狙い
[MASK] に置換 80% 予測対象を明示的に隠す
ランダムtokenに置換 10% 入力のノイズに頑健にする
そのまま残す 10% [MASK] がFine-tuning時に出ない差を緩和する

たとえば、次の文があるとします。

the man went to the store

store が予測対象になった場合、入力は次のようになります。

the man went to the [MASK]

BERTは左右の文脈を使って、隠されたtokenが store である確率を高くするように学習します。

損失は、予測対象になった位置だけで計算します。

LMLM=iMlogp(tixM)

$M$ はmask対象位置の集合、$x_{\setminus M}$ は一部がmaskされた入力です。

この学習により、各tokenの表現は周囲の文脈を使って意味を補う方向へ調整されます。

Next Sentence Prediction

BERTのNext Sentence Predictionで文Aと文Bが連続するかをCLS表現から判定する流れ

BERT論文では、Masked Language Modelingに加えて、Next Sentence Predictionも使われています。

Next Sentence Prediction(2つの文が連続しているかを判定する事前学習タスク)は、文ペアの関係を学ぶための目的です。

入力は文Aと文Bのペアです。

50%は実際に連続する文、50%はランダムに選んだ文Bです。

モデルは [CLS] の最終表現を使い、文Bが文Aの次に来るかを分類します。

LNSP=logp(yh[CLS])

学習目的 単位 何を学ばせたいか
MLM token 文脈から欠けた語を補う
NSP 文ペア 2文の関係を判定する

なお、後続研究ではNSPの必要性について再検討されています。

そのため、現在のLLM理解では「BERT論文の設計としてNSPが使われた」と押さえつつ、すべての後続モデルに必須とは考えない方が安全です。

事前学習済みEmbeddingをどう使うのか

事前学習済みBERTの文脈化Embeddingを分類、抽出、質問応答、検索へ転用する流れ

BERTの価値は、事前学習済みの表現を下流タスクへ転用できる点にあります。

論文では、タスクごとに大きな専用モデルを作るのではなく、BERT本体を初期化に使い、最後に小さな出力層を足してFine-tuning(事前学習済みモデルを特定タスクに追加学習すること)します。

タスク 使う表現
文分類 [CLS] の最終表現 感情分類、自然言語推論
token分類 各tokenの最終表現 固有表現抽出
質問応答 各tokenの開始/終了スコア SQuAD形式の抽出型QA
文検索 文全体の表現 類似文検索、RAGの検索器

注意点として、BERTの [CLS] 表現をそのまま汎用のsentence embeddingとして使うと、用途によっては期待ほど強くない場合があります。

検索や類似度計算では、Sentence-BERTのように文ペア類似度向けに調整されたモデルが使われることが多いです。

つまり、「BERTは文脈化表現を作る」ことと、「そのまま最高の検索Embeddingになる」ことは分けて考える必要があります。

実験結果の要約

BERT論文では、GLUE、MultiNLI、SQuADなど、複数の自然言語理解タスクで当時のstate-of-the-art(当時最高水準)を更新したと報告されています。

arXiv概要では、GLUE score 80.5%、MultiNLI accuracy 86.7%、SQuAD v1.1 Test F1 93.2、SQuAD v2.0 Test F1 83.1が示されています。

評価 論文で示された結果の意味
GLUE 文分類や含意認識など複数タスクの総合評価
MultiNLI 文ペアの自然言語推論性能
SQuAD v1.1 抽出型質問応答で答え範囲を当てる性能
SQuAD v2.0 答えが存在しない質問も含む質問応答性能

この結果は、BERTの文脈化Embeddingが単独の特徴量として便利というだけでなく、Fine-tuningによって幅広いNLPタスクへ適応できることを示しています。

ただし、2018年当時の比較であり、現在の生成LLMの性能を直接比較するものではありません。

BERTは、生成よりも理解・分類・抽出に強いEncoder型モデルとして位置づけると分かりやすいです。

よくある誤解

誤解 正確な見方
Embeddingは単語の意味そのものを保存した辞書である Embeddingは学習で得られるベクトル表現であり、意味は文脈とタスクを通じて現れる
token embeddingだけ見ればBERTの理解が分かる BERTで重要なのはTransformer層を通った後の文脈化表現
[CLS] は常に最高のsentence embeddingである 分類には便利だが、類似度検索では専用に学習されたモデルが有利なことが多い
position embeddingは補助的なので不要 Self-Attentionだけでは語順を直接区別しにくいため、位置情報が必要
BERTはGPTと同じように文章生成が得意 BERTはEncoder型で、主に理解・分類・抽出に向く

実装者視点で見るEmbedding

実装では、Embeddingは巨大な行列です。

語彙サイズ $V$ と隠れ次元 $H$ が大きくなるほど、Embedding層のパラメータ数も増えます。

parameters=V×H

BERT Baseのように、語彙サイズが約30,000、隠れ次元が768なら、token embeddingだけで約2,300万パラメータになります。

これはモデル全体の中でも無視できないサイズです。

設計項目 実装上の注意
語彙サイズ 大きいほどEmbedding行列が重くなる
位置長 最大系列長を超える入力は切り詰めや分割が必要
segment ID 単文タスクではすべて0にする実装が多い
padding attention maskと合わせて扱う必要がある
weight tying 出力層とEmbeddingを共有する設計もある

Embeddingの差し替えは、tokenizerの差し替えと強く結びつきます。

token IDとEmbedding行列の行が対応しているため、tokenizerだけ変えると、同じIDが別のtokenを指してしまう可能性があります。

Fine-tuning済みモデルでは、tokenizer、Embedding、出力層をセットで管理する必要があります。

実装例:BERT風の入力Embeddingを作る

以下は、BERT風にtoken embedding、segment embedding、position embeddingを足し合わせる最小例です。

実際のBERT実装では、LayerNorm(層ごとに値のスケールを整える処理)やdropout(過学習を抑えるために一部の値を落とす処理)も組み合わせます。

import logging

import torch
from torch import Tensor, nn

logger = logging.getLogger(__name__)


class BertStyleInputEmbedding(nn.Module):
    """Build BERT-style input embeddings from token, segment, and position IDs.

    Args:
        vocab_size: Number of tokens in the tokenizer vocabulary.
        hidden_size: Embedding dimension used by the Transformer encoder.
        max_position_embeddings: Maximum sequence length supported by the model.
        segment_vocab_size: Number of segment IDs. BERT uses two for sentence A/B.

    Returns:
        A module that maps ID tensors to summed embedding tensors.

    Raises:
        ValueError: If one of the size arguments is not positive.

    Examples:
        >>> module = BertStyleInputEmbedding(30522, 768, 512, 2)
        >>> token_ids = torch.tensor([[101, 2023, 2003, 102]])
        >>> segment_ids = torch.zeros_like(token_ids)
        >>> module(token_ids, segment_ids).shape
        torch.Size([1, 4, 768])
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        max_position_embeddings: int,
        segment_vocab_size: int = 2,
    ) -> None:
        super().__init__()
        if vocab_size <= 0:
            raise ValueError("vocab_size must be positive")
        if hidden_size <= 0:
            raise ValueError("hidden_size must be positive")
        if max_position_embeddings <= 0:
            raise ValueError("max_position_embeddings must be positive")
        if segment_vocab_size <= 0:
            raise ValueError("segment_vocab_size must be positive")

        self.token_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.segment_embeddings = nn.Embedding(segment_vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, token_ids: Tensor, segment_ids: Tensor | None = None) -> Tensor:
        """Return summed BERT-style embeddings.

        Args:
            token_ids: Tensor shaped `(batch, seq_len)` containing tokenizer IDs.
            segment_ids: Optional tensor shaped like `token_ids`. If omitted, all tokens use segment 0.

        Returns:
            Tensor shaped `(batch, seq_len, hidden_size)`.

        Raises:
            ValueError: If `token_ids` is not 2D or if `segment_ids` has a mismatched shape.

        Examples:
            >>> module = BertStyleInputEmbedding(100, 16, 32)
            >>> ids = torch.tensor([[1, 2, 3]])
            >>> module(ids).shape
            torch.Size([1, 3, 16])
        """
        if token_ids.ndim != 2:
            raise ValueError("token_ids must have shape (batch, seq_len)")

        batch_size, seq_len = token_ids.shape
        if segment_ids is None:
            segment_ids = torch.zeros_like(token_ids)
        if segment_ids.shape != token_ids.shape:
            raise ValueError("segment_ids must have the same shape as token_ids")

        position_ids = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        position_ids = position_ids.expand(batch_size, seq_len)

        logger.debug("build embeddings: batch=%d seq_len=%d", batch_size, seq_len)
        # BERTの入力形式に合わせるため、3種類のID情報を同じhidden_sizeの空間で足し合わせる。
        embeddings = (
            self.token_embeddings(token_ids)
            + self.segment_embeddings(segment_ids)
            + self.position_embeddings(position_ids)
        )
        logger.info("created input embeddings: shape=%s", tuple(embeddings.shape))
        return self.dropout(self.layer_norm(embeddings))

関連技術

技術 Embeddingとの関係
Tokenization 文字列をtoken IDに変換し、Embeddingの入力を作る
Positional Encoding tokenの順序を表す情報を追加する
Self-Attention Embedding同士の関係を計算し、文脈化表現を作る
Sentence-BERT BERTを文類似度・検索向けに調整する
Dense Retrieval 文書や質問をベクトル化して近傍検索する

次に読むべき記事

  • Tokenization編:文字列がtoken IDになるまでを理解する
  • Attention編:Embedding同士がどう相互作用するかを理解する
  • Transformer編:BERTやGPTを支える基本構造を理解する
  • RAG/Dense Retrieval編:Embeddingを検索に使う実務設計を理解する

まとめ

Embeddingは、token IDをベクトルへ変換するLLMの入口です。

BERT論文では、token embedding、segment embedding、position embeddingを足し合わせ、Transformer Encoderへ入力します。

そして、Masked Language Modelingにより、各tokenの表現は左右の文脈を反映した文脈化Embeddingへ変わります。

Embeddingを単なる「単語を数値にする表」と見ると、BERTの本質を見落とします。

本当に重要なのは、入力EmbeddingがSelf-Attentionを通じて文脈に応じて変化し、分類、抽出、質問応答などに転用できる表現になることです。

コメント

タイトルとURLをコピーしました