Ragged Paged Attention:TPU向け高性能・柔軟なLLM推論カーネル

arXiv cs.AI / 2026/4/20

📰 ニュースDeveloper Stack & InfrastructureTools & Practical UsageModels & Research

要点

  • 本論文では、動的で「ragged」な実行パターンが多いサービング環境で、TPU上でLLM推論を効率化するためのTPU向けアテンションカーネル「Ragged Paged Attention(RPA)」を提案しています。
  • RPAは、raggedメモリに対する効率的な動的スライシングを可能にするきめ細かなタイル化、KVキャッシュ更新とアテンション計算を融合する独自のソフトウェアパイプライン、デコード/プレフィル/混在ワークロード向けに専用カーネルを生成するコンパイル戦略によって性能と柔軟性を高めています。
  • Llama 3 8BをTPU7xで評価した結果、デコード時に最大86%のメモリ帯域利用率(MBU)、プレフィル時に73%のモデルFLOPs利用率(MFU)を達成しています。
  • RPAはPallasとMosaicで実装され、vLLMおよびSGLangにおいてTPUバックエンドの主要実装として統合されており、TPU推論カーネル設計の実運用レベルの基盤と実践的な知見を提供することを目指しています。

Abstract

大規模言語モデル(LLM)のデプロイは、性能と総所有コスト(TCO)の両方を重視しつつ、GoogleのTensor Processing Units(TPU)のようなコスト効率の高いアクセラレータへと、ますます移行しています。しかし、既存のLLM推論カーネルおよびサービングシステムは依然として主にGPU中心であり、特に現代のサービングで一般的な動的でラグド(ragged)な実行パターンのもとで、LLMワークロードをTPUアーキテクチャへ効率的にマッピングするための確立された手法は存在しません。本論文では、PallasとMosaicを用いて実装した、高性能かつ柔軟なTPU向けアテンションカーネル「Ragged Paged Attention(RPA)」を提案します。RPAは、次の3つの主要な技術によりこれらの課題に対処します。(1)ラグドなメモリ上で効率的な動的スライシングを可能にするきめ細かなティリング、(2)KVキャッシュ更新とアテンション計算を融合するカスタムのソフトウェアパイプライン、(3)デコード、プリフィル、混在ワークロード向けに特殊化したカーネルを生成する、分布を考慮したコンパイル戦略です。TPU7x上でLlama 3 8Bを評価したところ、RPAはデコードで最大86%のメモリ帯域利用率(MBU)、プリフィルで73%のモデルFLOPs利用率(MFU)を達成します。vLLMおよびSGLangにおける主要なTPUバックエンドとして統合されたRPAは、効率的なTPU推論のためのプロダクション品質の基盤を提供し、カーネル設計に関する実用的な知見をもたらします。