AI Navigate

コンパイラファーストの状態空間双対性と推論のためのポータブルな$O(1)$自己回帰キャッシュ

arXiv cs.LG / 2026/3/11

Developer Stack & InfrastructureTools & Practical UsageModels & Research

要点

  • 本論文は、NVIDIAハードウェアに依存したカスタムCUDAやTritonカーネルを不要にする、XLAコンパイラ最適化を活用した状態空間モデル推論の新たなアプローチを提案します。
  • Mamba-2アルゴリズムは、対角状態構造とチャンク可能な再帰性を持つ状態空間双対性を活用し、XLAのフュージョンやタイル処理技術に適合した効率的な計算を実現しています。
  • 実装は、プリフィルとキャッシュされた自己回帰デコードを含む完全な推論ワークフローをサポートし、$O(1)$の状態管理を実現、CPU、NVIDIA GPU、Google Cloud TPU上で修正なしに動作します。
  • TPU v6eでの性能ベンチマークでは、高効率を示し、プリフィルで最大140 TFLOPS、デコード時に64%の帯域幅利用率を達成し、結果はPyTorch/CUDAのリファレンスと完全に一致します。
  • この手法は、構造条件を満たす他の状態空間モデル再帰にも一般化可能であり、Bonsai JAXモデルライブラリ内でオープンソースとして公開されています。

概要: 状態空間モデルのリリースは通常、NVIDIAハードウェアに強く依存するカスタムのCUDAおよびTritonカーネルに結び付けられています。我々は、Mamba-2の状態空間双対性アルゴリズム——対角状態構造、チャンク可能な再帰、および静的制御フローを伴うeinsum支配の計算——がXLAのフュージョンおよびタイルパスが実際に最適化する部分にきれいにマッピングされ、カスタムカーネルが必須ではなく任意になることを示します。ハンドライティングされたカーネルなしで、XLAの下で成形された標準プリミティブとして完全な推論パス(プリフィル、キャッシュされた自己回帰デコード)を実装し、生成時にホスト同期を必要としない、コンパイル済みのオンデバイスキャッシュとして理論的なO(1)状態管理を実現します。実装は単一のJAXソースから、CPU、NVIDIA GPU、およびGoogle Cloud TPU上で修正なしに動作します。TPU v6e上で5つのモデル規模(1億3千万〜27億パラメータ)に渡り、XLA生成コードは単一ストリームプリフィルで約140 TFLOPS(15%のMFU)、デコード時には最大64%の帯域幅利用率に達します。貪欲デコードは64ステップにわたりPyTorch/CUDAの参照トークンと1対1で一致し、隠れ状態はfloat32の丸め誤差範囲内で同意します。このパターンは、同じ構造条件を満たす任意のSSM再帰に移植可能であり、成熟したXLAバックエンドを持つ任意のプラットフォームで利用可能です。実装は https://github.com/CosmoNaught/mamba2-jax にて公開され、Bonsai JAXモデルライブラリに統合されています。