こんにちは r/LocalLLaMA、
私は現在、Google Research の最近の TurboQuant(QJL)論文 の概念を、Apple Silicon 向けに MLX へネイティブ実装する作業を進めています。この論文では、ほぼ精度損失ゼロのまま、KV キャッシュを大幅に圧縮(1ビット/3ビットまで)できると主張しています。
私は正常に動作する実装(TurboKVCacheMLX)をローカルの mlx_lm ライブラリへ直接組み込み、Llama-3.2-3B モデルで実運用ベンチマークを完了しました。
結果は有望ですが、「Python の壁」に当たっており、これをカスタム Metal カーネル側へ移す際のフィードバックや手がかりが欲しいです。
実装 & 実世界での結果
標準の KV キャッシュのドロップイン置き換えとして、次のことを実装しました。
- 外れ値の特定: 分散が大きい「座標の外れ値」(例:16 次元)を追跡し、それらは FP16 のまま保持します。
- インライヤ(内点)のスケッチ: 残りの「インライヤ」に直交射影行列を適用します。
- 量子化: 射影したインライヤを、1ビットの符号表現(> 0)として圧縮します。
ベンチマーク:Llama-3.2-3B(28 層)
テストとして、生成開始は標準の FP16 で行い、その後生成の途中でキャッシュ全体をホットスワップして、KVCache.to_turbo() という新しいメソッドで TurboQuant 側へ切り替えました。
- 標準キャッシュ(FP16): 28.00 MB
- Turbo キャッシュ(1ビットキー + FP16 の外れ値 + FP16 の値): 16.30 MB
- 総メモリ削減: 総 KV キャッシュのフットプリントが 41.8% 削減(キーは特に約 80% 圧縮)
- コヒーレンス: ホットスワップ後もモデルは完全にコヒーレントを維持しました:「universe is approximately 13.8 billion years old. The Big Bang theory is the leading explanation...」
- 変換レイテンシ: 28 層すべてのホットスワップにかかったのは0.01 秒だけでした。
助けてほしい点 / フィードバック募集
数学は成立しており、GQA のルーティングも堅実で、メモリ削減も現実的です。ただし、ビットのパッキング/アンパッキングが現時点で最大のボトルネックです。私の _pack_bits と _unpack_bits は、標準の mlx.core のブール配列とビット演算を使っていますが、これは GPU のコマンドキュー上では非常に非効率で、標準の FP16 よりセットアップが速くなることを妨げています。
MLX でまだ、1ビット量子化や重いビットパッキングをネイティブに扱ったことのある人はいませんか?
- カスタム Metal カーネル: 注意のドット積中に行う、この種のビットアンパックのために、
mlx.core.fastでカスタム Metal カーネルをラップする例や手がかりはありますか? - MLX オペレーション: 中間配列の割り当てが爆発しない形で、1ビットの符号射影を扱う、より「MLX ネイティブ」な方法はありますか?
- 推定器(Estimator)の最適化: QJL は事前計算されたインライヤのノルムを使って、1ビットのドット積のバイアスを取り除きます。スループットを最大化するために、これを MLX 上でより良い形に構造化できる方法はありますか?
私は PoC のロジックをオープンソースにしましたので、どんな批評でも、関連リポジトリへの案内でも歓迎します。これらの極端な量子化方式で Metal からさらに性能を絞り出すためのアドバイスがあれば、大きな助けになります。
[link] [comments]
