FlashAttentionとは?AttentionをIO-awareに高速化する仕組み

FlashAttentionの全体像

FlashAttentionの全体像

TransformerのSelf-Attentionは、系列長が長くなるほど、計算量だけでなくメモリアクセスも急増します。

FlashAttentionは、Attentionの計算結果を近似せず、GPUのHBM(High Bandwidth Memory、高帯域メモリ)とSRAM(オンチップの高速小容量メモリ)の間の読み書きを減らすことで高速化する手法です。

論文では、標準Attentionで大きな \(N \times N\) 行列をHBMへ保存する設計を避け、tiling(行列を小さなブロックに分けて処理する手法)とrecomputation(逆伝播時に必要な値を再計算する手法)により、長いコンテキストのTransformer学習を高速・省メモリにできると示されています。

3文要約

FlashAttention論文の3文要約

FlashAttentionは、Self-Attentionを「数式上の計算量」だけでなく「GPUメモリ階層間のIO(データ読み書き)」から見直した論文です。

標準Attentionが \(N \times N\) のattention行列をHBMへ書き戻すのに対し、FlashAttentionはQ、K、Vをブロックに分けてSRAM上で計算し、softmaxの正規化統計だけを更新しながらexact attention(近似しないAttention)を得ます。

論文では、GPT-2のAttention計算で最大7.6倍、GPT-2学習で3倍、Long Range Arenaで2.4倍の高速化が報告され、長文脈Transformerを実用化する重要な基盤技術になりました。

論文情報

FlashAttention論文の基本情報

項目 内容
論文タイトル FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
著者 Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
研究機関 Stanford University, University at Buffalo, SUNY
発表年 2022年
arXiv提出 2022年5月27日、v2は2022年6月23日
論文リンク FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
PDF arXiv PDF
公式実装 Dao-AILab/flash-attention

本記事では、論文の主張と筆者の解釈を分けて扱います。

論文で示されているのは、「Attentionの数式を変えずに、GPU上でのメモリアクセスを削減すると高速化できる」という点です。

そのため、FlashAttentionを単なる近似Attentionや、単なるCUDA実装テクニックとして理解すると、本質を見落としやすくなります。

本記事の目的

本記事で理解するポイント

本記事では、FlashAttentionの新規性を次の順番で整理します。

  1. なぜ標準Self-Attentionは長い系列で遅く、メモリを使うのか
  2. FLOPs(浮動小数点演算回数)だけでは実行時間を説明できない理由
  3. FlashAttentionがHBMとSRAMのIOをどう減らすのか
  4. online softmax(分割した入力でもsoftmaxを正しく計算する方法)の数式
  5. recomputationにより、逆伝播時の \(N \times N\) 保存をどう避けるのか
  6. 実験結果、応用先、よくある誤解

対象読者は、TransformerのAttention式は見たことがあるが、GPUメモリ階層やCUDAカーネル最適化までは詳しくないAI初学者から中級者です。

背景:標準Self-Attentionはなぜ重いのか

標準Self-Attentionのメモリボトルネック

Self-Attentionでは、入力からQuery、Key、Valueを作り、次の式で出力を計算します。

\[
S = QK^{\mathrm{T}}
\]

\[
P = \mathrm{softmax}(S)
\]

\[
O = PV
\]

ここで、系列長を \(N\)、head dimension(1つのattention headが持つ特徴次元)を \(d\) とすると、\(Q,K,V \in \mathbb{R}^{N \times d}\) です。

一方、\(S\) と \(P\) は \(N \times N\) の行列です。

系列長が2倍になると、\(S\) と \(P\) の要素数は4倍になります。

たとえば \(N=16,384\) の場合、\(N \times N\) は約2.68億要素です。

これは、単に計算が多いだけでなく、中間行列をHBMへ書き、次の演算でまた読むことが大きな負担になります。

段階 生成される主な値 メモリ上の問題
\(QK^{\mathrm{T}}\) score行列 \(S \in \mathbb{R}^{N \times N}\) 系列長の2乗で大きくなる
masking/dropout mask済みscoreやdropout済み確率 中間値の読み書きが増える
softmax attention確率 \(P \in \mathbb{R}^{N \times N}\) 逆伝播のため保存されやすい
\(PV\) 出力 \(O \in \mathbb{R}^{N \times d}\) 出力自体は線形サイズだが、途中が重い

従来の近似Attentionは、この \(N^2\) の負担を減らすために、疎構造、低ランク近似、局所Attentionなどを使います。

ただし、論文では、多くの近似AttentionはFLOPsを減らしても、実際のwall-clock time(実時間)では高速化しない場合があると指摘しています。

理由は、GPUでは演算器が速くなった一方で、HBMへの読み書きがボトルネックになりやすいからです。

論文解説:FlashAttentionの中核アイデア

FlashAttentionのtilingとonline softmax

FlashAttentionの中核は、Attentionの数式を変えるのではなく、計算順序を変えることです。

標準Attentionでは、\(QK^{\mathrm{T}}\) 全体を作り、softmaxをかけ、\(V\) と掛けます。

FlashAttentionでは、\(Q,K,V\) をブロックに分割し、KとVのブロック、QのブロックをSRAMへ読み込み、ブロック単位でattention出力を足し合わせます。

このとき重要なのは、softmaxが行全体に依存することです。

単純に列方向を分割すると、各ブロックのsoftmaxを別々に計算して足すだけでは、全体のsoftmaxと一致しません。

FlashAttentionは、online softmaxにより、分割したブロックを処理しながら、行ごとの最大値と正規化項を更新します。

Online Softmax:分割してもsoftmaxを正しく計算する

softmaxを数値安定に計算するため、ベクトル \(x\) に対して最大値 \(m(x)\) を引きます。

\[
m(x) = \max_i x_i
\]

\[
f(x)_i = e^{x_i – m(x)}
\]

\[
\ell(x) = \sum_i f(x)_i
\]

\[
\mathrm{softmax}(x)_i = \frac{f(x)_i}{\ell(x)}
\]

ここで、\(x\) を2つのブロック \(x^{(1)}, x^{(2)}\) に分けたとします。

全体の最大値は次のように更新できます。

\[
m(x) =
\max
\left(
m(x^{(1)}),
m(x^{(2)})
\right)
\]

正規化項も、各ブロックの最大値との差を使って合わせられます。

\[
\ell(x)
=
e^{m(x^{(1)})-m(x)}\ell(x^{(1)})
+
e^{m(x^{(2)})-m(x)}\ell(x^{(2)})
\]

この式により、全列を一度に見なくても、各ブロックを処理しながらsoftmaxの分母を正しく更新できます。

FlashAttentionでは、各Qブロックについて、Kブロックを順に読み込み、scoreブロックを計算し、行ごとの \(m\) と \(\ell\)、出力 \(O\) を更新します。

結果として、\(N \times N\) のattention行列をHBMに保存しなくても、標準Attentionと同じ出力を得られます。

Tiling:HBMではなくSRAMでブロック計算する

tilingは、行列を小さなブロックに分割し、高速なSRAMへ載せて計算する手法です。

FlashAttentionでは、KとVのブロックを外側のループで読み込み、Qのブロックを内側のループで読み込みます。

各ブロックの計算はSRAM上で行い、最終的な出力だけをHBMへ書き戻します。

観点 標準Attention FlashAttention
中間attention行列 \(N \times N\) をHBMに保存しやすい 保存しない
softmax 全score行列を作ってから処理 ブロックごとに統計量を更新
メモリアクセス HBMへの読み書きが多い SRAMで計算しHBM IOを削減
Attentionの値 exact exact
実装の要件 通常の高レベル演算で書きやすい CUDAカーネルなど低レベル制御が重要

論文では、A100 GPUを例に、HBMは大容量だがSRAMより遅く、SRAMは非常に速いが小容量であると説明しています。

FlashAttentionは、この非対称なメモリ階層を前提に、読み書きの回数を減らす設計です。

Recomputation:逆伝播で何を保存し、何を再計算するか

FlashAttentionのrecomputation

学習では、順伝播だけでなく逆伝播も必要です。

標準的な実装では、逆伝播のためにscore行列 \(S\) やattention確率 \(P\) を保存します。

しかし、これらは \(N \times N\) であり、長い系列ではメモリを大きく圧迫します。

FlashAttentionは、順伝播では出力 \(O\) とsoftmaxの正規化統計 \(m,\ell\) を保存します。

逆伝播では、Q、K、Vのブロックを再びSRAMへ読み込み、必要なscoreやsoftmax値をブロックごとに再計算します。

保存するもの 標準Attention FlashAttention
出力 \(O\) 保存 保存
score \(S\) 保存されやすい 保存しない
attention確率 \(P\) 保存されやすい 保存しない
softmax統計 \(m,\ell\) 通常は主役ではない 保存して再計算に使う
逆伝播時の方針 保存した大きな中間値を読む 小さな統計からブロックを再計算

recomputationは、FLOPsを増やす可能性があります。

一見すると、再計算するなら遅くなりそうです。

しかし、論文では、HBM IOを減らせるため、総合的には逆伝播も高速化できると示しています。

ここがFlashAttentionの重要なポイントです。

FLOPsを少し増やしても、遅いメモリアクセスを大きく減らせるなら、実時間は短くなる場合があります。

IO複雑性:FLOPsではなくメモリアクセスを見る

FlashAttentionの論文は、IO complexity(メモリ階層間の読み書き回数の複雑性)を分析しています。

論文では、標準Attentionが少なくとも次のオーダーのHBMアクセスを必要とすると整理しています。

\[
\Omega(Nd + N^2)
\]

一方、FlashAttentionは、SRAMサイズを \(M\) とすると、次のオーダーでHBMアクセスを抑えられると分析しています。

\[
O(N^2 d^2 M^{-1})
\]

ここで、\(M\) が大きいほど、SRAMへ載せられるブロックが大きくなり、HBMへの往復を減らしやすくなります。

この式は、FlashAttentionが「Attentionの計算量 \(O(N^2d)\) を消す」手法ではないことを示しています。

むしろ、同じexact attentionを計算しながら、遅いメモリ階層へのアクセスを減らす手法です。

指標 見ているもの FlashAttentionでの意味
FLOPs 演算回数 recomputationにより増える場合がある
メモリ容量 保存すべきテンソル量 \(N \times N\) attention行列を保存しないため削減
HBM IO HBMとの読み書き量 主な削減対象
wall-clock time 実際の実行時間 GPUや実装、系列長に依存する

FlashAttentionの意義は、理論上の演算量だけを見ると見えにくいボトルネックを、GPU実装の観点から正面に置いた点にあります。

Block-sparse FlashAttentionとの違い

Attention手法の比較

論文では、FlashAttentionをblock-sparse attention(ブロック単位で疎なAttentionを行う近似手法)へ拡張しています。

通常のFlashAttentionはexactです。

つまり、計算順序は変えますが、Attentionとして得られる値は標準Attentionと同じです。

一方、block-sparse FlashAttentionは、見るブロックを制限するため近似になります。

手法 Attentionの値 主な狙い 注意点
標準Attention exact 実装が単純で汎用的 \(N \times N\) 中間行列が重い
近似Attention approximate FLOPsやメモリを減らす 品質低下や実時間高速化不足が起こり得る
FlashAttention exact HBM IOを減らす GPUカーネル実装が重要
Block-sparse FlashAttention approximate より長い系列へ拡張 sparsity設計が品質に影響する

この違いは重要です。

「FlashAttentionは近似Attentionだから速い」と説明すると誤りです。

通常のFlashAttentionは、標準Attentionと同じ結果を、よりIO効率のよい順序で計算する手法です。

実験結果:どの程度速くなったのか

FlashAttentionの実験結果

論文では、モデル学習、長文脈タスク、Attention単体のベンチマークでFlashAttentionを評価しています。

主な報告値を整理すると次の通りです。

評価 論文で報告された結果 読み方
BERT-large, seq. length 512 MLPerf 1.1の学習速度記録に対して15%のend-to-end高速化 Attention単体ではなく学習全体の改善
GPT-2, seq. length 1K HuggingFace/Megatron-LM系ベースラインに対して3倍高速化 長い系列で効果が出やすい
Long Range Arena, seq. length 1K-4K 2.4倍高速化 長文脈タスクでの学習時間改善
GPT-2のAttention計算 PyTorch実装に対して最大7.6倍高速化 Attentionカーネル単体の改善
GPT-2 perplexity 長い文脈により0.7改善 高速化により長文脈を扱いやすくなる
Path-X FlashAttentionで61.4% accuracy 16K系列への拡張例
Path-256 Block-sparse FlashAttentionで63.1% accuracy 64K系列への拡張例

これらは論文の実験条件での報告値です。

すべてのモデル、GPU、sequence lengthで同じ倍率が出るわけではありません。

ただし、長い系列でAttentionの中間行列が支配的になるほど、FlashAttentionのIO削減は効きやすくなります。

高画質タスクへの応用:画像、動画、VLMで何が嬉しいか

高画質タスクへの応用

FlashAttentionは、LLMだけでなく、画像・動画・VLM(Vision-Language Model、画像と言語を同時に扱うモデル)でも重要です。

画像をパッチに分けてTransformerへ入力するVision Transformerでは、高解像度になるほどtoken数が増えます。

動画では、空間方向のパッチに加えて時間方向のフレームも増えるため、系列長はさらに大きくなります。

応用先 Attentionが重くなる理由 FlashAttentionで期待できること
高解像度画像認識 パッチ数が増える より大きな入力解像度を扱いやすくなる
画像生成 latentやパッチ系列が長くなる DiT系モデルのAttentionメモリを抑えやすい
動画理解・動画生成 フレーム数と空間パッチが同時に増える 長い時間文脈を扱いやすくなる
文書OCR・レイアウト理解 画像パッチと文字tokenが増える 長文書の文脈統合に向く
VLM 画像tokenと言語tokenを同時に扱う 長いマルチモーダル文脈に向く

カメラ向けの高画質化タスクで考えると、Denoise(ノイズ除去)、Demosaic(カラーフィルタ配列からRGB画像を復元する処理)、Super Resolution(超解像)などは、高解像度画像を扱うためメモリ負担が大きくなりやすいです。

Transformer型の画像復元モデルへFlashAttentionを組み込めれば、より大きな画像パッチ数や広い文脈を扱いやすくなる可能性があります。

一方で、組み込み機器やエッジデバイスでは、GPUの種類、SRAM容量、対応カーネル、電力制約が大きく効きます。

FlashAttentionは「考え方」としては有用ですが、実際に高速化できるかは、対象ハードウェアで同等のメモリ階層制御ができるかに依存します。

実装者視点で気をつけること

FlashAttentionを使うときは、単にライブラリ名を有効にするだけでなく、前提条件を確認する必要があります。

確認項目 なぜ重要か
GPUアーキテクチャ 対応カーネルや最適化度が変わる
head dimension カーネルが対応する形状でないと速くならない場合がある
sequence length 短い系列では通常実装との差が小さい場合がある
dropout/mask causal maskやpadding mask対応で実装経路が変わる
dtype fp16、bf16、fp8などで対応状況が異なる
training/inference 逆伝播があるかでrecomputationの意味が変わる

FlashAttentionは、Attention式を理解するだけではなく、GPUカーネル、メモリ階層、テンソル形状まで含めて効果が決まります。

LLMや画像生成モデルの高速化で使う場合も、ベンチマークは対象モデル・対象GPU・対象batch sizeで測る必要があります。

よくある誤解

FlashAttentionのよくある誤解

よくある誤解 正確な情報・解釈
FlashAttentionは近似Attentionである 通常のFlashAttentionはexact attentionです。計算順序を変えているだけです
FlashAttentionはFLOPsを減らす手法である 主役はHBM IO削減です。recomputationによりFLOPsは増える場合もあります
\(N^2\) の問題が完全になくなる attentionの計算構造自体は \(N^2\) の側面を残します。主に保存メモリとIOを減らします
どんなモデルでも必ず速い sequence length、GPU、head dimension、実装経路に依存します
CUDAの小技にすぎない online softmax、tiling、IO complexity解析を含むアルゴリズムと実装の共同設計です
Block-sparse FlashAttentionもexactである block-sparse版は見るブロックを制限するため近似です

特に重要なのは、FlashAttentionが「近似しないまま高速化する」点です。

多くのAttention高速化手法は、計算対象を減らすことで高速化します。

FlashAttentionは、計算対象そのものを削るのではなく、同じ計算をGPUにとって効率のよい順序で実行します。

まとめ

FlashAttentionのまとめと次に読むべき記事

FlashAttentionは、TransformerのAttentionを、FLOPsだけでなくGPUメモリ階層のIOから見直した論文です。

標準Attentionでは \(N \times N\) のscore行列やattention確率をHBMへ保存しやすく、長い系列でメモリと読み書きが支配的になります。

FlashAttentionは、tiling、online softmax、recomputationを組み合わせ、\(N \times N\) attention行列を保存せずにexact attentionを計算します。

論文では、BERT-large、GPT-2、Long Range Arena、Path-X/Path-256などで、長文脈Transformerの高速化とメモリ削減が報告されています。

ただし、FlashAttentionは魔法の高速化ではありません。

効果はGPU、系列長、head dimension、dtype、mask、実装カーネルに依存します。

それでも、現在のLLM、画像生成、VLMで長いコンテキストを扱ううえで、「Attentionは数式だけでなくメモリアクセスが重要」という視点を広めた点で、非常に重要な論文です。

関連技術

関連技術 FlashAttentionとの関係
Transformer Self-Attentionを中核に持つモデル構造
Self-Attention FlashAttentionが高速化する対象
CUDA kernel fusion 複数演算を1つのGPUカーネルにまとめ、メモリ読み書きを減らす実装技術
Gradient checkpointing 中間値を保存せず再計算する考え方。FlashAttentionのrecomputation理解に近い
Sparse Attention attention対象を疎にする近似手法。block-sparse FlashAttentionと関係する
FlashAttention-2 FlashAttentionの後続研究で、並列化とワーク分割を改善する
KV Cache 推論時の過去Key/Value保存。FlashAttentionとは別軸だが長文脈LLMで併用される

次に読むべき記事

記事案 狙い 優先度
Self-Attentionとは?Transformerを支える仕組みを数式で解説 FlashAttentionの前提を理解する
GPUメモリ階層とは?HBM、SRAM、レジスタの違いを整理 IO-awareの背景を理解する
FlashAttention-2とは?並列化とワーク分割の改善を解説 FlashAttentionの発展を追う
KV Cacheとは?LLM推論のメモリ使用量を理解する 長文脈推論の別ボトルネックを理解する
Sparse Attentionとは?長文脈Transformerの近似手法を比較 block-sparse版との違いを理解する

コメント

タイトルとURLをコピーしました