LLM推論のためのカスタムTritonカーネルにしばらく取り組んでいました。最新のプロジェクトは、素朴なアプローチの24回以上の代わりに、5回のカーネル起動で順伝播全体を処理する、MoEディスパッチの融合パイプラインです。
Mixtral-8x7B(A100)での結果:
| Tokens | PyTorchに対して | Megablocksに対して |
|---|---|---|
| 32 | 4.9x | 131% |
| 128 | 5.8x | 124% |
| 512 | 6.5x | 89% |
32および128トークン(実際の推論サービングの多くが起きる領域)では、スタンフォードのCUDA最適化Megablocksより高速です。512以上では、手チューニングされたブロックスパース行列積でMegablocksが上回ります。
重要なコツは、ゲートとアップ射影を融合して、両方のGEMMがL2キャッシュから同じ入力タイルを共有するようにし、SiLU活性はグローバルメモリに一切触れずにレジスタ上で行うことです。これにより、Mixtralの順伝播あたりのメモリ転送が約470MB削減されます。
また、DeepSeek-V3(256エキスパート)とQwen2-MoEでもテストしました。AMD MI300X上でコード変更なしでフルテストを実行し、全162テストがパスしました。
コード:https://github.com/bassrehab/triton-kernels
屋根線解析(roofline analysis)付きの完全な書き下ろし:https://subhadipmitra.com/blog/2026/fused-moe-dispatch-triton/
[link] [comments]




