最近、FlashAttention-PyTorch リポジトリを更新し、素の PyTorch で FA1、FA2、FA3、FA4 の教育用実装が含まれるようにしました。
主な目的は、コードから各バージョン間の進展を理解しやすくすることです。
これは最適化されたカーネルのリポジトリを意図しているわけではありません。また、公式実装をハードウェアに忠実に再現したものでもありません。狙いは、アルゴリズム上のアイデアや設計上の変更点を、すぐに CUDA/Hopper/Blackwell 固有の詳細に深く踏み込まずに示すことです。
ざっくり言うと、リポジトリは次を示しています:
- FA1: タイル化されたオンライン softmax のベースライン
- FA2: split-Q / クエリ・タイルの所有、deferred normalization(遅延正規化)
- FA3: ping-pong のタイル・バッファを用いた明示的な段階的パイプライン、加えて簡略化した教育用の FP8 フォワード経路
- FA4: メイン / softmax / correction(補正)フェーズを備えた明示的なスケジューラ、ならびに条件付き/選択的なリスケーリング
そのため、まったく同じ注意(attention)の数式は保持しつつ、オーケストレーションの変更がバージョンごとに行われます。
私は、次を理解したい人のために書きました:
"FA1 → FA2 → FA3 → FA4 で、実際に何が変わったの?""
高度に最適化された CUDA カーネルから始める必要なしに。
リポジトリ: https://github.com/shreyansh26/FlashAttention-PyTorch
コードが、バージョン間の違いを直感的に理解できるものになっているかどうかについて、フィードバックがあればうれしいです。
[link] [comments]




