剪定されたビジョントランスフォーマーに向けたディスパッチ対応ラギッド注意(Dispatch-Aware Ragged Attention)

arXiv cs.AI / 2026/4/20

📰 ニュースDeveloper Stack & InfrastructureModels & Research

要点

  • この論文は、FlashAttention-2 varlenやPyTorchのNestedTensor SDPAのような可変長注意APIを使うと、Vision Transformer(ViT)に対するトークン剪定が注意レイテンシを期待ほど減らせない理由を分析しています。
  • 分析の結果、ディスパッチ(ホスト側呼び出し)オーバーヘッドがボトルネックであることが示され、剪定後の典型的なトークン数(≤197)では行列演算は数マイクロ秒で終わる一方、ホスト側のディスパッチが60–90マイクロ秒かかると述べています。
  • 著者らは、ディスパッチのフロアを約40マイクロ秒まで下げることを狙った、軽量な双方向Triton注意カーネルを提案し、壁時計時間における剪定の効果を見えやすくします。
  • pack–attend–unpackの一連のパイプラインに組み込むことで、4つの剪定手法とDeiTの複数モデルサイズにわたり、パディング付きPyTorch SDPAに対して最大2.24×のエンドツーエンドスループットを達成し、分類予測はビット同等(最大絶対ログit差<0.007)を維持します。
  • 総じて、本研究は剪定の性能を「FLOP削減」だけでなく、ViTで一般的な短いシーケンスに対するカーネル/ディスパッチのオーバーヘッド最適化として捉え直しています。

Abstract

Vision Transformer(ViT)のトークン枝刈り(pruning)手法は、不情報なパッチを削除することで注意(attention)FLOPsを二次的に削減することを約束します。 しかし、枝刈りされた系列を、最先端の可変長注意API――FlashAttention-2のvarlenや、PyTorchのNestedTensor SDPAを含む――で実行すると、ウォールクロックの注意レイテンシはそれに応じてスケールしません。私たちはこれを、ディスパッチ(dispatch)オーバーヘッドのボトルネックに起因すると特定します。すなわち、ViTで典型的な短い、枝刈り後の系列長(<=197トークン)では、実際の行列演算は数マイクロ秒(1桁)で完了する一方で、ホスト側のディスパッチ経路が60〜90 µsを消費します。私たちは、ディスパッチの下限を約40 µsに抑える、軽量な双方向Triton注意カーネルを提示します。これは、FlashAttention-2のvarlenよりおよそ1.5倍低いため、枝刈りによる削減がウォールクロック時間でより見えるようになります。完全なpack-attend-unpackパイプラインに統合することで、私たちのシステムは、4つの枝刈りアルゴリズム(Threshold-L2、DynamicViT、EViT、ATS)すべてにおいて、パディングされたPyTorch SDPAに対して一貫して最大2.24xのエンドツーエンドスループットを達成します。さらに、DeiT-T/S/B間でスケールし、<0.007の最大絶対ロジット差でビット一致の分類予測を維持します。