こんにちは r/LocalLLaMA、
モダンなMLコンパイラのスタックについて、シンプルな全体像を思いつきました。要するに、model.generate() の呼び出しからGPUがカーネルを実行するまでの間に何が起きているのか、という話です。しかし、このスタックを読むのはかなり大変です。TVMはC++の50万行超です。PyTorchはDynamo、Inductor、Tritonをその上に積み重ねています。さらにXLA、MLIR、Halide、Mojoもあります。
そこで、別のアプローチとして、最初から1つ作ることにしました。純粋なPythonと生のCUDAだけです。小さなモデル(Qwen2.5-7B、TinyLlama)を取り、それを一連のCUDAカーネルにコンパイルします。目標は今すぐTritonに勝つことではなく、コンパイラ分野で博士号がなくても修正できる(少なくとも追いやすい)ハッカブルなコンパイラを作ることです。
最終的な性能は、プロダクションスタックの約50〜90%です(PyTorch Eagerおよびtorch.compileと比較して)。
私は原理に基づいて作りました。レイヤードなパイプラインを用い、関心事を明確に分離しています:
- Torch IR — FXグラフをキャプチャ(rmsnorm、linear、softmax、...)
- Tensor IR — すべての演算をElementwise / Reduction / IndexMapに分解
- Loop IR — 複数のカーネルと融合されたループネストとして書かれるカーネル
- Tile IR — GPUへスケジュールされるカーネル(スレッド、ブロック、共有メモリ)
- Kernel IR — スケジュールをハードウェアプリミティブに具現化
- CUDA — nvccでコンパイル可能なソースとして出力
Tensor IRは、ONNXやJaxのような将来のフロントエンドを支えるために導入しています。ループ融合が長いpointwiseおよびreductionのチェーンの融合を扱います。ロワリングの段階では、タイル化したmatmul、smemステージング、ダブルバッファリングといった最適化が導入されます。
各段階は独立して検査・デバッグできます(リポジトリへのリンク)。GPUは不要です:
deplodock compile -c "nn.RMSNorm(2048)(torch.randn(1,32,2048))" --ir tensor|loop|tile|kernel|cuda ベンチマーク:
deplodock run --bench --profile -c "torch.nn.Softmax(dim=-1)(torch.randn(1,28,2048,2048))" エンドツーエンドのコンパイル:
deplodock compile Qwen/Qwen2.5-7B RMSNormの生成されたCUDAカーネルは、こんな感じです:
extern "C" __global__ __launch_bounds__(256) void k_rms_norm_reduce(const float* x, const float* p_weight, float* rms_norm) { float in0 = 2048.0f; float in1 = 1e-06f; { int a1 = blockIdx.x; int a0 = threadIdx.x; float acc0 = 0.0f; __syncthreads(); __shared__ float x_smem[2048]; for (int x_smem_flat = a0; x_smem_flat < 2048; x_smem_flat += 256) { { unsigned int _smem_addr = __cvta_generic_to_shared(&x_smem[x_smem_flat]); asm volatile("cp.async.ca.shared.global [%0], [%1], 4;
" :: "r"(_smem_addr), "l"(&x[a1 * 2048 + x_smem_flat]) : "memory"); } } asm volatile("cp.async.commit_group;
" ::: "memory"); asm volatile("cp.async.wait_group 0;
" ::: "memory"); __syncthreads(); __shared__ float p_weight_smem[2048]; for (int p_weight_smem_flat = a0; p_weight_smem_flat < 2048; p_weight_smem_flat += 256) { { unsigned int _smem_addr = __cvta_generic_to_shared(&p_weight_smem[p_weight_smem_flat]); asm volatile("cp.async.ca.shared.global [%0], [%1], 4;
" :: "r"(_smem_addr), "l"(&p_weight[p_weight_smem_flat]) : "memory"); } } asm volatile("cp.async.commit_group;
" ::: "memory"); asm volatile("cp.async.wait_group 0;
" ::: "memory"); __syncthreads(); for (int a2 = a0; a2 < 2048; a2 += 256) { float in2 = x_smem[a2]; float v0 = in2 * in2; acc0 += v0; } __shared__ float acc0_smem[256]; acc0_smem[a0] = acc0; __syncthreads(); for (int s = 128; s > 0; s >>= 1) { if (a0 < s) { acc0_smem[a0] = acc0_smem[a0] + acc0_smem[a0 + s]; } __syncthreads(); } __syncthreads(); float acc0_b = acc0_smem[0]; float v1 = acc0_b / in0; float v2 = v1 + in1; float v3 = rsqrtf(v2); for (int a3 = a0; a3 < 2048; a3 += 256) { float in3 = x_smem[a3]; float in4 = p_weight_smem[a3]; float v4 = in3 * v3; float v5 = v4 * in4; rms_norm[a1 * 2048 + a3] = v5; } } } [リンク] [コメント]

