[P] Pure Tritonでの融合MoEディスパッチ:推論バッチサイズにおいてCUDA最適化済みMegablocksを上回る

Reddit r/MachineLearning / 2026/4/6

💬 オピニオンSignals & Early TrendsIdeas & Deep AnalysisTools & Practical UsageModels & Research

要点

  • この記事では、(CUDAやベンダー固有のコードなしで)Tritonのみで実装したフォワードパスのMixture-of-Experts(MoE)ディスパッチ用カーネルを説明する。
  • A100上でMixtral-8x7Bを動かした場合、Tritonの手法は推論に関係するバッチサイズでStanfordのMegablocksを上回り、32トークンで131%、128トークンで124%の性能を達成する。
  • 入力タイルのロードを再利用し、SiLUをレジスタ上で計算する「融合ゲート+アップ射影」を導入することで、中間バッファの使用量を削減し、メモリ転送を約35%(フォワードパスあたり約470MB)削減する。
  • さらに、ブロックスケジューリングされたグループ化GEMMを採用し、事前計算した「block_idから(expert_id, offset)への対応」を用いることで、パディングなしで、可変サイズのエキスパートバッチを単一のカーネル起動で処理する。
  • 複数のMoEモデル(Mixtral-8x7B、256エキスパートのDeepSeek-V3、Qwen2-MoE)に対して完全なテストを通過し、コード変更なしでAMD MI300Xでも動作するとのこと。

純粋なTritonだけで、Mixture-of-Expertsモデルのフォワードパス全体を処理する fused MoE dispatch カーネルを構築しました。CUDAなし、ベンダー固有のコードなし。

Mixtral-8x7B(A100)では、推論に関係するバッチサイズで Stanford の Megablocks を上回ります(32トークンで131%、128トークンで124%)。より大きいバッチでは、期待どおり Megablocks の手でチューニングされたCUDAが先行します。

主な貢献は2つです:

  1. ゲート+アップ投影の融合 - 両方のGEMMで同じ入力タイルのロードを共有し、SiLUはレジスタで計算します。フォワードパスあたり中間バッファを約470MB削減(メモリ転送35%削減)。
  2. ブロックスケジュールされたグループ化GEMM - precomputed block_id を(expert_id, offset)へマッピングすることで、パディングなしで単一のカーネル起動において可変サイズのエキスパートバッチを処理します。

Mixtral-8x7B、DeepSeek-V3(256エキスパート)、Qwen2-MoEでテスト済み。コード変更なしで AMD MI300X 上のフルテストスイートにパスします。

コード:https://github.com/bassrehab/triton-kernels

解説:https://subhadipmitra.com/blog/2026/fused-moe-dispatch-triton/

投稿者 /u/bassrehab
[リンク] [コメント]