PyTorchにおけるFlashAttention(FA1〜FA4)—アルゴリズム上の違いに焦点を当てた教育用実装 [P]

Reddit r/MachineLearning / 2026/4/12

💬 オピニオンIdeas & Deep AnalysisTools & Practical UsageModels & Research

要点

  • 著者はFlashAttention-PyTorchリポジトリを更新し、FlashAttentionの各バリアントFA1〜FA4について、教育目的の実装(最適化なし、素のPyTorch)を含めました。
  • このリポジトリは、同じ注意(attention)の数式のオーケストレーションが、FA1のタイル化されたオンラインsoftmaxのベースラインから、FA4のフェーズ付きスケジューラ(main/softmax/correction)へ、さらに選択的なリスケーリングを伴ってどのように進化するかを理解するために設計されています。
  • 各バージョンでは、クエリー・タイルの所有権や、遅延正規化(FA2)などの特定のアルゴリズム上の変更点が強調されます。また、ピングポン・バッファを用いた段階的パイプライン処理と、簡略化した教育用FP8フォワード経路(FA3)、条件付きリスケーリングの仕組み(FA4)も取り上げています。
  • これらの実装は、公式のCUDA/Hopper/Blackwellカーネルをハードウェア忠実に再現することを明確に避け、読みやすいコードを通じて設計上のアイデアを可視化することに重点を置いています。
  • 著者は、このコードがバージョン間の違いを直感的に理解しやすいかどうかについてフィードバックを呼びかけ、レビュー用のGitHubリポジトリへのリンクも提示しています。

最近、FlashAttention-PyTorch リポジトリを更新し、素の PyTorch で FA1、FA2、FA3、FA4 の教育用実装が含まれるようにしました。

主な目的は、コードから各バージョン間の進展を理解しやすくすることです。

これは最適化されたカーネルのリポジトリを意図しているわけではありません。また、公式実装をハードウェアに忠実に再現したものでもありません。狙いは、アルゴリズム上のアイデアや設計上の変更点を、すぐに CUDA/Hopper/Blackwell 固有の詳細に深く踏み込まずに示すことです。

ざっくり言うと、リポジトリは次を示しています:

  • FA1: タイル化されたオンライン softmax のベースライン
  • FA2: split-Q / クエリ・タイルの所有、deferred normalization(遅延正規化)
  • FA3: ping-pong のタイル・バッファを用いた明示的な段階的パイプライン、加えて簡略化した教育用の FP8 フォワード経路
  • FA4: メイン / softmax / correction(補正)フェーズを備えた明示的なスケジューラ、ならびに条件付き/選択的なリスケーリング

そのため、まったく同じ注意(attention)の数式は保持しつつ、オーケストレーションの変更がバージョンごとに行われます。

私は、次を理解したい人のために書きました:

"FA1 → FA2 → FA3 → FA4 で、実際に何が変わったの?""

高度に最適化された CUDA カーネルから始める必要なしに。

リポジトリ: https://github.com/shreyansh26/FlashAttention-PyTorch

コードが、バージョン間の違いを直感的に理解できるものになっているかどうかについて、フィードバックがあればうれしいです。

submitted by /u/shreyansh26
[link] [comments]