純粋なTritonで融合(fused)のMoEディスパッチカーネルを書き、推論バッチサイズにおいてMixtralおよびDeepSeekでMegablocksを上回った

Reddit r/LocalLLaMA / 2026/4/6

💬 オピニオンDeveloper Stack & InfrastructureIdeas & Deep AnalysisTools & Practical Usage

要点

  • 本投稿は、Tritonのみで実装したカスタムの融合Mixture-of-Experts(MoE)推論ディスパッチ・パイプラインを紹介しており、素朴な実装に比べてMoEのフォワードパスを約5回のカーネル起動(naiveでは24回以上)まで削減している。
  • Mixtral-8x7B(A100)でのベンチマークでは、一般的なサービングのバッチ/トークンサイズにおいてPyTorchより大幅な高速化を示しており(例:〜4.9〜6.5倍)、32〜128トークンではMegablocksを上回る。一方で、512トークン以上ではMegablocksが最適化されたブロックスパース行列積により再びリードする。
  • 中核となる最適化は、ゲートとアップ射影GEMMを融合して同一のL2に常駐する入力タイルを再利用し、さらにSiLUをレジスタ上で計算することでグローバルメモリへの往復を回避し、フォワードパスあたりのメモリ転送量を大幅に削減する点にある。
  • 著者はDeepSeek-V3(256エキスパート)やQwen2-MoEに対する追加の検証も報告しており、コード変更なしでAMD MI300Xへの移植が可能で、162件すべてのテストがパスしたことを示している。
  • コードおよび詳細な書き下ろし(roofline解析を含む)はGitHubリポジトリとブログ記事として公開されており、高性能なMoE推論カーネルを実務者が利用しやすくすることを目的としている。

LLM推論のためのカスタムTritonカーネルにしばらく取り組んでいました。最新のプロジェクトは、素朴なアプローチの24回以上の代わりに、5回のカーネル起動で順伝播全体を処理する、MoEディスパッチの融合パイプラインです。

Mixtral-8x7B(A100)での結果:

Tokens PyTorchに対して Megablocksに対して
32 4.9x 131%
128 5.8x 124%
512 6.5x 89%

32および128トークン(実際の推論サービングの多くが起きる領域)では、スタンフォードのCUDA最適化Megablocksより高速です。512以上では、手チューニングされたブロックスパース行列積でMegablocksが上回ります。

重要なコツは、ゲートとアップ射影を融合して、両方のGEMMがL2キャッシュから同じ入力タイルを共有するようにし、SiLU活性はグローバルメモリに一切触れずにレジスタ上で行うことです。これにより、Mixtralの順伝播あたりのメモリ転送が約470MB削減されます。

また、DeepSeek-V3(256エキスパート)とQwen2-MoEでもテストしました。AMD MI300X上でコード変更なしでフルテストを実行し、全162テストがパスしました。

コード:https://github.com/bassrehab/triton-kernels

屋根線解析(roofline analysis)付きの完全な書き下ろし:https://subhadipmitra.com/blog/2026/fused-moe-dispatch-triton/

submitted by /u/bassrehab
[link] [comments]