フィードバックを探しています:GoogleのTurboQuant(QJL)KVキャッシュ圧縮をMLXへ移植

Reddit r/LocalLLaMA / 2026/3/25

💬 オピニオンDeveloper Stack & InfrastructureTools & Practical UsageModels & Research

要点

  • 開発者が、Apple Silicon上でGoogle ResearchのTurboQuant(QJL)KVキャッシュ圧縮手法をMLXへ移植し、mlx_lm内のTurboKVCacheMLXで「ドロップイン」のKVキャッシュ置き換えとして実装している。
  • この方法では、高分散の「座標アウトライヤー」をFP16のまま保持し、残りの「インライヤー」は直交射影行列を用いてスケッチし、射影したインライヤーをFP16値を保持したまま1ビット符号表現に圧縮する。
  • Llama-3.2-3B(28層)のベンチマークで、生成途中にFP16からTurboKVCacheへホットスワップする実験において、KVキャッシュのメモリフットプリントを41.8%削減し、特にキーに関しては約80%の圧縮を報告している。テスト実行中にコヒーレンシ(整合性)問題は観測されなかった。
  • 全28層に対するホットスワップは約0.01秒で完了するとされており、実行時の切り替えにも十分に機能していることが示唆される。
  • 主な性能ボトルネックは、現在のmlx.coreのboolean配列/ビット演算により非効率な1ビットのビットパッキング/アンパッキングを行っている点であり、これらの操作を高速化するためにカスタムMetalカーネル(mlx.core.fast)を使う方法について助言を求めている。

こんにちは r/LocalLLaMA

私は現在、Google Research の最近の TurboQuant(QJL)論文 の概念を、Apple Silicon 向けに MLX へネイティブ実装する作業を進めています。この論文では、ほぼ精度損失ゼロのまま、KV キャッシュを大幅に圧縮(1ビット/3ビットまで)できると主張しています。

私は正常に動作する実装(TurboKVCacheMLX)をローカルの mlx_lm ライブラリへ直接組み込み、Llama-3.2-3B モデルで実運用ベンチマークを完了しました。

結果は有望ですが、「Python の壁」に当たっており、これをカスタム Metal カーネル側へ移す際のフィードバックや手がかりが欲しいです。

実装 & 実世界での結果

標準の KV キャッシュのドロップイン置き換えとして、次のことを実装しました。

  1. 外れ値の特定: 分散が大きい「座標の外れ値」(例:16 次元)を追跡し、それらは FP16 のまま保持します。
  2. インライヤ(内点)のスケッチ: 残りの「インライヤ」に直交射影行列を適用します。
  3. 量子化: 射影したインライヤを、1ビットの符号表現(> 0)として圧縮します。

ベンチマーク:Llama-3.2-3B(28 層)

テストとして、生成開始は標準の FP16 で行い、その後生成の途中でキャッシュ全体をホットスワップして、KVCache.to_turbo() という新しいメソッドで TurboQuant 側へ切り替えました。

  • 標準キャッシュ(FP16): 28.00 MB
  • Turbo キャッシュ(1ビットキー + FP16 の外れ値 + FP16 の値): 16.30 MB
  • 総メモリ削減: 総 KV キャッシュのフットプリントが 41.8% 削減(キーは特に約 80% 圧縮)
  • コヒーレンス: ホットスワップ後もモデルは完全にコヒーレントを維持しました:「universe is approximately 13.8 billion years old. The Big Bang theory is the leading explanation...」
  • 変換レイテンシ: 28 層すべてのホットスワップにかかったのは0.01 秒だけでした。

助けてほしい点 / フィードバック募集

数学は成立しており、GQA のルーティングも堅実で、メモリ削減も現実的です。ただし、ビットのパッキング/アンパッキングが現時点で最大のボトルネックです。私の _pack_bits_unpack_bits は、標準の mlx.core のブール配列とビット演算を使っていますが、これは GPU のコマンドキュー上では非常に非効率で、標準の FP16 よりセットアップが速くなることを妨げています。

MLX でまだ、1ビット量子化や重いビットパッキングをネイティブに扱ったことのある人はいませんか?

  1. カスタム Metal カーネル: 注意のドット積中に行う、この種のビットアンパックのために、mlx.core.fast でカスタム Metal カーネルをラップする例や手がかりはありますか?
  2. MLX オペレーション: 中間配列の割り当てが爆発しない形で、1ビットの符号射影を扱う、より「MLX ネイティブ」な方法はありますか?
  3. 推定器(Estimator)の最適化: QJL は事前計算されたインライヤのノルムを使って、1ビットのドット積のバイアスを取り除きます。スループットを最大化するために、これを MLX 上でより良い形に構造化できる方法はありますか?

私は PoC のロジックをオープンソースにしましたので、どんな批評でも、関連リポジトリへの案内でも歓迎します。これらの極端な量子化方式で Metal からさらに性能を絞り出すためのアドバイスがあれば、大きな助けになります。

submitted by /u/vbenjaminai
[link] [comments]