広告

対応していないAMD GPU向けに、単純なPyTorchのフラッシュ・アテンション代替を作った

Reddit r/LocalLLaMA / 2026/3/28

💬 オピニオンDeveloper Stack & InfrastructureTools & Practical Usage

要点

  • 著者は、AMD MI50(gfx906)GPU上で動画生成モデルを動かすのが難しいと述べている。理由は、PyTorchに当該アーキテクチャ向けのメモリ効率の良い/フラッシュ・アテンションのサポートがなく、そのためメモリ使用量が急増したり、実行が極端に遅くなったりするため。
  • 彼らは、一般的な“fused(融合)”アテンションの手法(Composable Kernel、AOTriton、Flash Attention ROCm、Triton)について、新しいGPU命令セット(gfx908+)を必要とするか、あるいは明確にgfx906を除外していると説明している。
  • fused attentionがない場合、PyTorchは“math SDPA”にフォールバックし、N×Nのアテンション行列全体を生成(materialize)してしまう。その結果、32GBのVRAMでは、より長い/高解像度の動画プロンプトが実行困難になる。
  • llama.cppの“対応していないGPU向けのタイル分割(tiling)フォールバック”から着想を得て、著者は「単純なPyTorchのフラッシュ・アテンション代替」を構築した。具体的には、フル行列ではなくクエリ/キーのチャンクを順に処理することで、メモリに収まるタイル単位でアテンションを計算する。
Built a simple PyTorch flash-attention alternative for AMD GPUs that don't have it

過去9か月間、手元のセットアップで32GBのMI50を2枚使っています。私の用途の大半は単にllama.cppに頼っているだけで、今は実に快適に動いています! (当時の状況と比べると大きな飛躍です)

たまに、遊び半分でComfyUIにも手を出し、新しいImageGen/AudioGenモデルを試したりしていました。ですが、私のMI50では、動画生成という特定の用途が現実的に全く不可能でした。

問題

以前Wan 2.2を触ったときのことを覚えています。シンプルな動画生成をすると、すぐにOOMになるか、あきらめて自分でプロセスを殺すまでに7〜9時間もかかるかでした。最新のLTXモデルでもうまくいきませんでした。

少し調べてみると、MI50s(gfx906)ではPyTorch上でメモリ効率の良い注意(attention)のサポートがゼロであることが分かりました。理由は、それに必要な行列積(matrix-multiplication)のコアを持っていないからです。すべての融合(fused)attention実装は、明確にgfx906を除外しています:

  • Composable Kernel(CK):MFMAの行列命令(gfx908+)が必要
  • AOTriton:コンパイル時にgfx906を拒否
  • Flash Attention ROCm:gfx90a+が必要
  • Triton:gfx906サポートは「未計画」としてクローズ

融合attentionがない場合、PyTorchはMath SDPAへフォールバックします。この場合、全N x Nのattentionスコア行列をそのまま生成(materialize)します。2.5秒の480p動画(17Kトークン)なら、attention層1つ分のスコア行列だけで26 GB。5秒の720p動画(75Kトークン)だと500 GB超です。32GBでは完全に無理です。

DIYのアプローチ

当然ながら、上記の調査結果を見てから、公式のFAサポートがないのにllama.cppが私のGPUでどう処理しているのか気になりました。すると、未対応GPU向けのフォールバックとして、汎用的なタイル(tiling)機構が用意されていることが分かりました。

それをヒントに、PyTorchでも同様のものを自作できないか検討することにしました。とはいえ、このコーディング領域は私にとって完全に新しい分野でしたが、AIの助けを借りながらなんとか道筋をつけられました。

核心となるアイデアはシンプルです。N x Nのスコア行列を一度に計算するのではなく、メモリに収まるサイズのチャンクに分割(タイル化)します。

S = Q @ K.T(17K+トークンでOOM)ではなく、まず小さなクエリチャンクごとにループして S_chunk = Q_chunk @ K.T(約1 GBに収まる)を計算し、softmaxを実行して、Vと掛けて累積します。計算は同じで、必要なメモリはO(N2)ではなくO(N)です。

理論上は簡単ですが、確実に動かすところまで持っていくのに約28回の試行が必要でした。考え出さないといけなかったこと:

うまくいったこと:

  • クエリ次元に沿ったタイル化(自動チューニングされたブロックサイズ)
  • 三段階のフォールバック:標準のチャンク化 -> オンラインsoftmax(Kタイル) -> in-placeの手動softmax
  • BF16 -> FP16への自動変換(gfx906にはBF16ハードウェアがありません)
  • ブロードキャストではなくFlattenされたGQA GEMM(ハードウェアの利用効率が向上)
  • FP16の非正規NaN問題を防ぐためのsoftmax FTZ(flush-to-zero)しきい値
  • 追加のメモリ節約のための、ランタイム安全性検証付きFFNのチャンク化

うまくいかなかった、または不要だったこと:

  • カスタムHIPカーネル — 純粋なPyTorchのmatmulでも十分速かった
  • Triton — gfx906サポートは実験的で、途中で打ち切られていた
  • 攻めたブロックサイズ — 小さいほど常に良いわけではなく、自動チューニングが最適点を見つける

到達点

カーネルは動作し、単一のMI50 32GBで以下が可能になりました:

動画生成(ComfyUI経由):

Model Resolution Duration Time Without kernel
Wan 2.2 5B 832x480 2.5s 5:04 OOM(38 GB必要)
Wan 2.2 5B 1280x720 5s 1:19:39 OOM(500+ GB必要)
LTX-2.3 22B 1280x704 音声付きで5.2s 20:18 OOM
LTX-2.3 22B 1920x1080 音声付きで5.2s 1:03:26 OOM

画像生成(Z-Image Turbo 6B via ComfyUI):

Resolution Without Kernel With Kernel Speedup VRAM Saved
512x512 22.1s / 25.6 GB 22.0s / 21.0 GB ~同等 18%
1024x1024 59.5s / 17.7 GB 57.2s / 15.4 GB 3%高速 13%
1536x1536 157.9s / 30.8 GB 112.7s / 16.4 GB 29%高速 47%

PyTorch LLM推論 — Qwen 2.5 0.5B(GQA、FP16):

Context Math SDPA With kernel Speedup
1K tokens 189 ms 178 ms 1.06x
2K tokens 437 ms 380 ms 1.15x
4K tokens 1209 ms 944 ms 1.28x
8K tokens 3985 ms 2734 ms 1.46x
16K tokens OOM 8880 ms

すべてのベンチマークは、単一のMI50 32GBで、128 GBのDDR4 RAMを搭載し、150Wの電力制限下で実施しました。

DRAMに関する重要な注意:これらのVideoGenワークフローはCPUへのオフロードに依存しており、さまざまな解像度や動画長で快適に実験するには少なくとも64 GBのDRAMが必要です。(参考として、Wan 2.2 5BとLTX 2.3に使ったワークフローは私のGitリポジトリで共有しています)

それと、気づきましたか?!

実はさらに速い!

カーネルのいちばん良い点は、Math SDPAがまだ動作できるシーケンス長においても、実際にMath SDPAを上回る性能を出していることです。分離したattentionベンチマーク(B=1、H=16、D=64、MI50上でFP16):

シーケンス長 Math SDPA noflash-attention スピードアップ 節約できたVRAM
256 0.28 ms / 47 MB 0.18 ms / 38 MB 1.6x 19%
512 0.55 ms / 79 MB 0.29 ms / 53 MB 1.9x 33%
1024 1.83 ms / 198 MB 0.85 ms / 106 MB 2.2x 46%
2048 8.72 ms / 652 MB 4.74 ms / 308 MB 1.8x 53%
4096 28.81 ms / 2424 MB 17.93 ms / 1096 MB 1.6x 55%
8192 102.42 ms / 9424 MB 72.75 ms / 1124 MB 1.4x 88%
16384 OOM 1325.69 ms / 1202 MB 唯一の選択肢

スピードアップは、おそらくL2キャッシュの利用が改善され、小さなチャンクが巨大なNxN行列を行き来してキャッシュを荒らす(thrashingする)のではなく、キャッシュ内で“熱い状態”のまま維持されるためでしょう。これはタイル状の注意(tiled attention)の基本的な性質です(Flash AttentionがNVIDIAでも高速なのと同じ理由なので)、多少正確な数値が違っても他のGPUでも同じ方向性になるはずです。私にとっては、これによりカーネルが“何でもPyTorch”の完璧なドロップイン置換になりました!

他にも役立ちそうな領域

上のベンチマークは私が個人的に実際に試したものですが、カーネルのパッチはSDPAの呼び出しをグローバルに適用します。つまりComfyUIや推論だけに限りません。理論上、次のような用途にも役立つはずです:

  • より長いコンテキストの微調整: Tier 1はautogradをサポートしているため、メモリ節約はそのまま学習に直結します。注意(attention)の途中でOOMになっていたコンテキスト長が、同じGPU上で収まるようになる可能性があります。より長いシーケンスでのLoRA微調整が現実的になります。
  • transformersを使う任意のPyTorchアプリ: diffusers、HuggingFace Transformersなど..。もし F.scaled_dot_product_attention を呼び出していて、GPU側に効率的なバックエンドがない場合、このカーネルによって使えるようになります。

gfx906から、より広いリリースへ

当初これは、私のMI50向けの単純なプライベートDIYでした。公開する予定はありませんでした。でも、そのアルゴリズムが純粋にPyTorchのmatmulであることに気づきました。フューズド注意(fused attention)を持たないすべてのAMD GPUには、まったく同じ問題があります:

  • Vega 56/64(gfx900)— MI50と同じ時代の世代で、MFMAは無し
  • RX 5600/5700(RDNA 1)— どのライブラリにもフューズド注意が無い
  • RX 6600-6900 XT(RDNA 2)— CKやAOTritonでもこれらはサポートされていない

現在、注意を多用するワークロードでMath SDPAに“固定”されているGPUの設置台数ベースは非常に大きいです。

そこで、GPUの自動検出付きの汎用的で、pipでインストールできるライブラリとしてパッケージ化しました。対応GPUでは、importするだけで完了です:

pip install noflash-attention import noflash_attention # auto-patches SDPA — done 

検出システムは起動時に、効率的なSDPAバックエンドがあるかを調べます。GPUにFlash Attentionまたはmem_efficientがあれば、そのまま何もしません。なければ自動で有効化されます。

リポジトリ: https://github.com/Lowkey-Loki-SN/noflash-attention

制限事項と貢献歓迎

以下の点については、先に正直にお伝えしたいです:

  • すべてのベンチマークは、単一のMI50 32GBからのものです。 テストするためのVega 56/64やRX 5000/6000のカードがありません。パフォーマンスはメモリ帯域幅、計算ユニット、そしてVRAMによって変わります。
  • マルチGPUは検証していません。 パッチはデータ並列(個々のSDPA呼び出しに対して動作するため)で動くはずですが、テンソル並列やリング注意はテストしていません。
  • 学習(Training): Tier 1(標準のchunked方式)はautogradをサポートしています。Tier 2と3は推論専用です。
  • torch.compileとCUDA graphs はサポートしていません(動的なブロックサイズのため)。
  • カーネル全体がvibeコードです。私は単にオーケストレーションし、テストし、方向性のアドバイスを提供しただけです。

もし上記のどれかのGPUを持っていて、このカーネルの恩恵がありそうなら、ぜひ試してみて、結果を教えてください!これはサイドプロジェクトなので、これ以上の改良に向けた継続的なコミットを保証はできませんが、不具合報告や互換性に関するフィードバックは歓迎します。コミュニティにやってもらいましょう!

ボーナス事実: ROCm 7.2 + PyTorch(ソースから)はgfx906で動く

その過程で、ROCm 7.2をgfx906で動かせるかどうかもテストしたかったのですが(公式にはサポートされていません)、答えははい、ソースからビルドすればです。私はROCm 7.2をコンパイルし、その上でそれに対してPyTorchをビルドしました。gfx906はまだ動きます!コンパイラ側のハードウェアサポート(LLVM/AMDGPU)は削除されていません。単に公式のビルドターゲットに入っていないだけです。私は1週間ほど使っていますが、現時点では安定しています。

最後は、MI50 1枚でこのカーネルを使って、LTX-2.3 22Bで生成した1080pの5秒の音声-映像クリップで締めます。

https://reddit.com/link/1s614i8/video/n3498o3alsrg1/player

submitted by /u/Lowkey_LokiSN
[link] [comments]

広告