Python 5,000行で作るハック可能なMLコンパイラ・スタック

Reddit r/MachineLearning / 2026/5/1

💬 オピニオンDeveloper Stack & InfrastructureModels & Research

要点

  • この記事は、TinyLlama や Qwen2.5-7B のような小型モデルを、6つの中間表現(IR)を経て生のCUDAカーネルへと落とし込む、純粋なPython約5,000行のリファレンスML(LLM)コンパイラを紹介しています。
  • 「ハック可能」で分かりやすいことを重視し、Tritonのような実運用コンパイラに勝つことではなく、教育と拡張性を目的にしている点が明確にされています。
  • パイプラインは、Torch FXグラフを扱うTorch IRから始まり、次に演算をElementwise/Reduction/IndexMapに分解するTensor IRへ進み、ONNXやJAXなどのフロントエンドを下流の段に手を入れずに差し替えられるようにします。
  • さらに後段では、ループIRで中間計算を融合してプロローグ/リダクション/エピローグを単一のループネストに畳み込むなど、よりハードウェアに近い変換を段階的に行い、最後にGPU向けのタイル化(Tile IR)へ進みます。
  • その実装パイプラインに対応したGitHubリポジトリ(deplodock)も提示されており、MLコンパイラのアーキテクチャを理解するための実用的な参照材料になっています。

やあ r/MachineLearning

最新のML(LLM)コンパイラ・スタックは残酷です。TVMはC++が50万行(+)です。PyTorchはDynamo、Inductor、Tritonをその上にさらに積み重ねています。さらにXLA、MLIR、Halide、Mojoもあります。これらのフレームワークの内部にいきなり放り込まれることなく、MLコンパイラの高レベル設計をカバーするチュートリアルはありません。

私は約5K行の純粋なPythonから、CUDAをそのまま出力するリファレンス・コンパイラをゼロから作りました。小さなモデル(TinyLlama、Qwen2.5-7B)を受け取り、6つのIRを通してCUDAカーネルの列へとロアします。目標はTritonに勝つことではなく、改造しやすく、追いやすい(ハック可能で、理解しやすい)コンパイラを作ることです。

全文記事: A Principled ML Compiler Stack in 5,000 Lines of Python

リポジトリ: deplodock

パイプラインは6つのIRで構成され、それぞれ最後のものよりもハードウェアに近づいていきます。次のPyTorchコードを、(名前は簡単のため短くし、コメントを追加した)実際のリファレンス・コンパイラ出力とともに、全ステージを通して見ていきます:

torch.relu(torch.matmul(x + bias, w)) # x: (16, 64), bias: (64,), w: (64, 16) 

Torch IR。キャプチャしたFXグラフで、PyTorchオペレーションの1:1ミラーです:

bias_bc = bias[j] -> (16, 64) float32 add = add(x, bias_bc) -> (16, 64) float32 matmul = matmul(add, w, has_bias=False) -> (16, 16) float32 relu = relu(matmul) -> (16, 16) float32 

Tensor IR。すべてのオペレーションをElementwise / Reduction / IndexMapに分解します。最小限の統一されたオペ面(op surface)なので、将来のフロントエンド(ONNX、JAX)でも下流のパスに触れずにそのまま差し込めます:

bias_bc = bias[j] -> (16, 64) float32 w_bc = w[j, k] -> (16, 64, 16) float32 add = add(x, bias_bc) -> (16, 64) float32 add_bc = add[i, j] -> (16, 64, 16) float32 prod = multiply(add_bc, w_bc) -> (16, 64, 16) float32 red = sum(prod, axis=-2) -> (16, 1, 16) float32 matmul = red[i, na, j] -> (16, 16) float32 relu = relu(matmul) -> (16, 16) float32 

(16, 64, 16)の中間表現は破滅的に見えますが、実体化されることはありません。次のステージで融合されます。

Loop IR。各カーネルには、隣接するカーネルと融合されたループネストがあります。プロローグ、ブロードキャストされた乗算、リダクション、出力レイアウト、エピローグがすべて、中間バッファなしで単一のループネストに畳み込まれます。

=== merged_relu -> relu === for a0 in 0..16: # free (M) for a1 in 0..16: # free (N) for a2 in 0..64: # reduce (K) in0 = load bias[a2] in1 = load x[a0, a2] in2 = load w[a2, a1] v0 = add(in1, in0) # prologue (inside reduce) v1 = multiply(v0, in2) acc0 <- add(acc0, v1) v2 = relu(acc0) # epilogue (outside reduce) merged_relu[a0, a1] = v2 

Tile IR。最初のGPUを意識したIRです。ループ軸をスレッド/ブロックへスケジューリングし、Stageが共有入力を共有メモリへ引き上げます。そして2×2のレジスタタイルによって、各スレッドが同時に4つの出力を蓄積できます。K軸は、幅32のリダクションを2回の外側イテレーションにタイル分割します。以下の3段階の注釈には、最も重い最適化がすべて詰まっています:

  • buffers=2@a2a2のKタイルループに沿ってsmem割り当てをダブルバッファリングし、イテレーションa2+1のロードが、a2の計算と重なるようにします。
  • asynccp.async.ca.shared.globalを発行して、ワープがグローバル→smem転送でブロックしないようにします;Kernel IR内のcommit_group/wait_groupフェンスと組みになります。
  • pad=(0, 1, 0) — 中央のsmem次元にパディングを1要素追加し、ワープ全体のロードがすべて同じバンクにヒットしないようにします。kernel k_relu_reduce Tile(axes=(a0:8=THREAD, a1:8=THREAD)): for a2 in 0..2: # K-tile # meta: double-buffered, sync (small, no async needed) bias_smem = Stage(bias, origin=((a2 * 32)), slab=(a3:32@0)) buffers=2@a2

kernel k_relu_reduce Tile(axes=(a0:8=THREAD, a1:8=THREAD)): for a2 in 0..2: # K-tile bias_smem = Stage(bias, origin=((a2 * 32)), slab=(a3:32@0)) buffers=2@a2 x_smem = Stage(x, origin=(0, (a2 * 32)), slab=(a0:8@0, a3:32@1, cell:2@0)) pad=(0, 1, 0) buffers=2@a2 async w_smem = Stage(w, origin=((a2 * 32), 0), slab=(a3:32@0, a1:8@1, cell:2@1)) buffers=2@a2 async # reduce for a3 in 0..32: in0 = load bias_smem[a2, a3] in1 = load x_smem[a2, a0, a3, 0]; in2 = load x_smem[a2, a0, a3, 1] in3 = load w_smem[a2, a3, a1, 0]; in4 = load w_smem[a2, a3, a1, 1] # prologue, reused 2× across N v0 = add(in1, in0); v1 = add(in2, in0) # 2×2 register tile acc0 <- add(acc0, multiply(v0, in3)) acc1 <- add(acc1, multiply(v0, in4)) acc2 <- add(acc2, multiply(v1, in3)) acc3 <- add(acc3, multiply(v1, in4)) # epilogue relu[a0*2, a1*2 ] = relu(acc0) relu[a0*2, a1*2 + 1] = relu(acc1) relu[a0*2 + 1, a1*2 ] = relu(acc2) relu[a0*2 + 1, a1*2 + 1] = relu(acc3) 

Kernel IR。スケジュールがハードウェアのプリミティブに具現化されます。THREAD/BLOCKはthreadIdx/blockIdxになります。async StageSmemcp.asyncに展開され、commit/waitのフェンスで埋め込みます。sync Stageはストライド付きの埋め込みループになります。フレームワーク非依存:同じIRはMetalやHIPにもロアできる可能性があります:

kernel k_relu_reduce Tile(axes=(a0:8=THREAD, a1:8=THREAD)): Init(acc0..acc3, op=add) for a2 in 0..2: # K-tile Smem bias_smem[2, 32] (float) StridedLoop(flat = a0*8 + a1; < 32; += 64): bias_smem[a2, flat] = load bias[a2*32 + flat] Sync # pad row to 33 to kill bank conflicts Smem x_smem[2, 8, 33, 2] (float) StridedLoop(flat = a0*8 + a1; < 512; += 64): cp.async x_smem[a2, flat/64, (flat/2)%32, flat%2] <- x[flat/64*2 + flat%2, a2*32 + (flat/2)%32] cp.async.commit_group; cp.async.wait_group(0); Sync Smem w_smem[2, 32, 8, 2] (float) StridedLoop(flat = a0*8 + a1; < 512; += 64): cp.async w_smem[a2, flat/16, (flat/2)%8, flat%2] <- w[a2*32 + flat/16, (flat/2)%8*2 + flat%2] cp.async.commit_group; cp.async.wait_group(0); Sync for a3 in 0..32: # reduce ... 

CUDA。Kernel IRをnvccの準備として、1対1でツリーウォークして変換します。バイアス加算、K軸のリダクション、2×2のレジスタタイル、relu活性化がすべて1つのカーネルに収まります。xbiaswのそれぞれについてHBMから1回読み込み、reluはHBMへ1回書き込みます。オペレーション間の中間値はありません。

extern "C" __global__ __launch_bounds__(256) void k_relu_reduce(const float* bias, const float* x, const float* w, float* relu) { long long tid = blockIdx.x * blockDim.x + threadIdx.x; if (tid < 64) { int a0 = tid / 8; int a1 = tid % 8; float acc0 = 0.0f, acc1 = 0.0f, acc2 = 0.0f, acc3 = 0.0f; #pragma unroll for (int a2 = 0; a2 < 2; a2++) { __shared__ float bias_smem[64]; for (int f = a0*8 + a1; f < 32; f += 64) bias_smem[a2*32 + f] = bias[a2*32 + f]; __syncthreads(); // bank conflict を避けるためにパディング __shared__ float x_smem[1056]; for (int f = a0*8 + a1; f < 512; f += 64) { unsigned int addr = __cvta_generic_to_shared( &x_smem[a2*528 + f/64*66 + f/2%32*2 + f%2] ); asm volatile( "cp.async.ca.shared.global [%0], [%1], 4;
" :: "r"(addr), "l"(&x[(f/64*2 + f%2)*64 + (a2*32 + f/2%32)]) : "memory"); } asm volatile("cp.async.commit_group;
" ::: "memory"); asm volatile("cp.async.wait_group 0;
" ::: "memory"); __syncthreads(); __shared__ float w_smem[1024]; for (int f = a0*8 + a1; f < 512; f += 64) { unsigned int addr = __cvta_generic_to_shared( &w_smem[a2*512 + f/16*16 + f/2%8*2 + f%2] ); asm volatile( "cp.async.ca.shared.global [%0], [%1], 4;
" :: "r"(addr), "l"(&w[(a2*32 + f/16)*16 + (f/2%8*2 + f%2)]) : "memory"); } asm volatile("cp.async.commit_group;
" ::: "memory"); asm volatile("cp.async.wait_group 0;
" ::: "memory"); __syncthreads(); #pragma unroll for (int a3 = 0; a3 < 32; a3++) { float in0 = bias_smem[a2*32 + a3]; float in1 = x_smem[a2*528 + a0*66 + a3*2 ]; float in2 = x_smem[a2*528 + a0*66 + a3*2 + 1]; float in3 = w_smem[a2*512 + a3*16 + a1*2 ]; float in4 = w_smem[a2*512 + a3*16 + a1*2 + 1]; float v0 = in1 + in0; float v1 = in2 + in0; acc0 += v0 * in3; acc1 += v0 * in4; acc2 += v1 * in3; acc3 += v1 * in4; } } relu[a0*2*16 + a1*2 ] = fmaxf(0.0f, acc0); relu[a0*2*16 + a1*2 + 1] = fmaxf(0.0f, acc1); relu[(a0*2+1)*16 + a1*2 ] = fmaxf(0.0f, acc2); relu[(a0*2+1)*16 + a1*2 + 1] = fmaxf(0.0f, acc3); } } 

すべてのステージは必要に応じて出力できます。GPU は不要です。

deplodock compile -c "torch.relu(torch.matmul(torch.randn(16,64) + torch.randn(64), torch.randn(64,16)))" --ir tensor|loop|tile|kernel|cuda 

eager PyTorch および torch.compile に対するベンチマーク(Qwen-block size での attention scores。コンパイラは torch.compile を同じ設定に固定します):

deplodock run --bench -c "torch.nn.Softmax(dim=-1)(torch.randn(1,28,2048,2048))" 

実モデルのエンドツーエンドなコンパイル:

deplodock compile Qwen/Qwen2.5-7B 

リンクされた記事では、設計を詳細に説明しています(RMSNorm はすべての IR を順に追い、σ ベースの融合アルゴリズムでは blowup ガードも扱います。TinyLlama および Qwen2.5-7B のブロックに対して torch.compile と照合しています)。続編の第 2 部ではコード生成の内部を取り上げます。

submitted by /u/NoVibeCoding
[link] [comments]