昨年、NVIDIA、ワルシャワ大学、エディンバラ大学に所属する研究者らがDynamic Memory Sparsification(DMS)を発表しました。これは、学習したヘッドごとのトークン退避(eviction)を用いたKVキャッシュのスパース化手法で、最大で8倍のKVキャッシュ圧縮を報告しています。
私はこの結果が興味深かったので、そのアイデアを検証するための小さなリファレンス実装とトレーナを作ることにしました。Llama 3.2 1BでWikiText-2を使うと、だいたい同等の再現ができました:
| 構成 | PPL | Delta | KLD(nats/tok) | 圧縮 |
|---|---|---|---|---|
| 通常のLlama-3.2-1B | 9.226 | - | - | 1x |
| DMS(学習済み、退避有効) | 9.200 | -0.28% | 0.026 | 6.4x |
DMSの予測器を学習するのにPRO 6000上で約20分かかり、圧縮結果は基本的に損失なし(lossless)に見えました。ただし、小さな問題が1つあって、私のHFリファレンス実装はおよそ... 18 tok/sでした。
そこで数週間カーネルを磨いた後、FastDMSを発表できて嬉しいです。MITライセンスのDMS実装で、退避されたスロットを物理的に回収するコンパクトなKVストレージを備えています。NVIDIAの元のQwen 3 8B DMSチェックポイントだけでなく、私自身が作成したLlama 3.2 1B DMSチェックポイントでもテストしています。(元のHFリファレンス版と私のトレーナもリポジトリにあります):https://github.com/shisa-ai/FastDMS
私のベンチマーク構成では、FastDMSは8KコンテキストでvLLM BF16 KVに比べて5-8x少ないKVメモリを使用しつつ、デコードもvLLMより1.5-2X速いです。
コンパクトDMSは、理論上のKVバイト数だけでなく、実際のアロケータ/デバイスのメモリを節約します。以下の表はctx_len=8192、gen_len=128を使用しています。すべてのvLLMベースラインは、ワークロードに合わせて正確にサイズ指定されたトークンプールを使います。KV/stageメモリは、キャッシュ、またはキャッシュ+ステージングのフットプリントです。vLLM BF16はdtype=bfloat16でkv_cache_dtype=autoです。vLLM FP8はkv_cache_dtype=fp8です。
| モデル/compact-DMS行 | c | vLLM BF16 KV → FastDMS KV | BF16 KV削減 | vLLM FP8 KV → FastDMS KV | FP8 KV削減 | vLLM TQ4 KV → FastDMS KV | TQ4 KV削減 |
|---|---|---|---|---|---|---|---|
| Llama-3.2-1B FastDMS デフォルト | 1 | 0.312 → 0.056 GiB | 5.6x | 0.156 → 0.056 GiB | 2.8x | 0.142 → 0.056 GiB | 2.5x |
| Llama-3.2-1B FastDMS デフォルト | 8 | 2.062 → 0.431 GiB | 4.8x | 1.031 → 0.431 GiB | 2.4x | 0.939 → 0.431 GiB | 2.2x |
| Qwen3-8B FastDMS compact DMS | 1 | 1.406 → 0.184 GiB | 7.6x | 0.703 → 0.184 GiB | 3.8x | — | — |
| Qwen3-8B FastDMS compact DMS | 8 | 9.281 → 1.462 GiB | 6.3x | 4.641 → 1.462 GiB | 3.2x | — | — |
気になっている方へ:はい、これはTurboQuantに対して速度とメモリ使用量の両方で上回っています:
| 経路 | c | プリフィル tok/s | プリフィル vs BF16 | デコード tok/s | デコード vs BF16 | KV / stageメモリ | 状態 |
|---|---|---|---|---|---|---|---|
| vLLM BF16 | 1 | 123098.0 | 1.00x | 459.4 | 1.00x | 0.312 GiB BF16 KV | 高密度BF16-KVのベースライン |
| vLLM FP8 | 1 | 119991.3 | 0.97x | 489.4 | 1.07x | 0.156 GiB FP8 KV | 高密度FP8-KVのベースライン |
vLLM TurboQuant 4bit_nc | 1 | 126429.0 | 1.03x | 333.4 | 0.73x | 0.142 GiB TQ4 KV | 4-bit KVのベースライン |
| FastDMS FP8 compact-DMS デフォルト | 1 | 123194.6 | 1.00x | 698.9 | 1.52x | 0.056 GiB | 昇格されたゼロ-BF16行 |
| FastDMS B46 int4 スピードプロファイル | 1 | 121489.9 | 0.99x | 1060.0 | 2.31x | 0.056 GiB + 0.719 GiB int4 shadow | デフォルトオフ:速度のためのストレージ |
| vLLM BF16 | 8 | 103668.5 | 1.00x | 2357.5 | 1.00x | 2.062 GiB BF16 KV | 高密度BF16-KVのベースライン |
| vLLM FP8 | 8 | 102959.5 | 0.99x | 2888.7 | 1.23x | 1.031 GiB FP8 KV | 高密度FP8-KVのベースライン |
vLLM TurboQuant 4bit_nc | 8 | 104409.9 | 1.01x | 1696.0 | 0.72x | 0.939 GiB TQ4 KV | 4-bit KVのベースライン |
| FastDMS FP8 compact-DMS デフォルト | 8 | 105531.7 | 1.02x | 3606.9 | 1.53x | 0.431 GiB | 昇格されたゼロ-BF16行 |
| FastDMS B25 narrow int4 スピードプロファイル | 8 | 104753.7 | 1.01x | 3640.7 | 1.54x | 0.431 GiB + 0.078 GiB int4 shadow | デフォルトオフ:速度のためのストレージ |
| FastDMS BF16-attention | speed control8 | 108070.5 | 1.04x | 3745.3 | 1.59x | 0.429 GiB + 0.312 GiB BF16 backing | 明示的な速度制御 |
もちろん、これらのことが圧縮タンクの出力品質に関係なければ意味がありません。理論上、DMS の eviction(退避)は FP8 の量子化 より前 に適用され、どのトークンを保持するか/退避するかを決めるため、FastDMS のコンパクトDMS の品質比較 は FP8 の量子化だけの場合と同じはずです。しかし、それでも品質は二重に確認する価値があります。
品質は、圧縮された KV キャッシュでトークンを生成し、非圧縮の参照データとトークンごとに比較することで測定します。KLD(KL divergence)が低いほど良いです。つまり、圧縮モデルの次トークン確率が参照により近いということです。トークン一致率が高いほど良いです。つまり、貪欲(greedy)デコードの出力が参照と同じであるということです。
列の読み方:
- KLD vs ref - 圧縮ロジットと参照ロジットの間の nats/token 単位の KL divergence。圧縮によって次トークン上の確率分布がどれくらいシフトするかを測ります。低いほど良いです。
0.000は同一を意味します。 - Token match - 貪欲デコードされたトークンのうち、参照と一致した割合。
96.9%は約 64 トークン中 2 トークンが異なったことを意味します。 - Tokens scored - 比較できたデコードステップ数。候補が参照と異なるトークンを生成した時点で系列は分岐し、その後のステップは比較できなくなります。
33/60は、分岐までの最初の 33 トークンのみ品質指標が対象であることを意味します。報告される KLD と PPL はそのプレフィックス上での値であり、生成全体ではありません。比率が高いほど比較がより完全になります。
テスト設定: ctx_len=1024、decode_len=16、4 つのプロンプト(合計で 60〜64 のデコードステップ)。vLLM の行は vLLM BF16 のフル-KV ロジットと比較します。FastDMS の行は eviction を無効化した FastDMS(参照ウィンドウ 100 万トークン、実質的にフルの KV キャッシュを保持)と比較します。
shisa-ai/Llama-3.2-1B-DMS-8x
| Path | Reference | KLD vs ref | Token match | PPL | Tokens scored |
|---|---|---|---|---|---|
| vLLM BF16 full KV | self | 0.000000 | 100.0% | 2.3748 | 60/60 |
| vLLM FP8 KV | vLLM BF16 | 0.005110 | 92.2% | 2.0893 | 33/60 |
vLLM TurboQuant 4bit_nc | vLLM BF16 | 0.012730 | 76.6% | 1.9606 | 22/60 |
| FastDMS FP8 compact-DMS | FastDMS no-evict | 0.003009 | 96.9% | 2.2810 | 64/64 |
nvidia/Qwen3-8B-DMS-8x
| Path | Reference | KLD vs ref | Token match | PPL | Tokens scored |
|---|---|---|---|---|---|
| vLLM BF16 full KV | self | 0.000000 | 100.0% | 1.6738 | 60/60 |
| vLLM FP8 KV | vLLM BF16 | 0.001042 | 70.3% | 1.1971 | 32/60 |
vLLM TurboQuant 4bit_nc | vLLM BF16 | 0.006039 | 84.4% | 1.4910 | 45/60 |
| FastDMS FP8 compact-DMS | FastDMS no-evict | 0.005284 | 95.3% | 1.8301 | 64/64 |
FastDMS の compact-DMS は両方のモデルで 64/64 トークンをスコアします。つまり、すべてのデコードステップが参照と比較可能であり、KLD は vLLM 自身の FP8 および TurboQuant の圧縮と比べて低い、または同等です。なお、Tokens scored が異なる場合、行間での PPL 値は直接比較できない点に注意してください。各行の PPL は異なる長さのプレフィックスに対して計算されるためです。
落とし穴は何?
それでは、これほど素晴らしいのに、なぜみんながすでに使っていないのでしょうか? 実は、生産環境のエンジン(vLLM のようなもの)にこれを実装しようとすると、大掛かりな改造 が必要になることが分かっています。DMS compact KV は、ほぼあらゆるサービング・エンジンのサブシステムに触れます:
| Subsystem | DMS で何が変わるか |
|---|---|
| PagedAttention / KV memory pool | DMS では、部分的なブロック解放を伴う、層ごと・ヘッドごとの可変トークン数が必要です。標準の固定ページブロックではありません |
| Prefill kernel | DMS 抽出後、密な KV ページを書き込むのではなく、生き残った K/V を層ごとのコンパクトストレージへストリーミングする必要があります |
| Decode kernel | 各デコードステップでヘッドごとの keep/evict を評価し、スライディングな保持ウィンドウを管理し、コンパクトストレージへ追記します |
| Attention scoring | 完全に置き換え:split-K の grouped compact decode attention を、可変長のヘッドごとのライブスパンに対して行う |
| Scheduler / admission | 密なフルシーケンスのページ数ではなく、コンパクト KV の容量に基づいてリクエストを受け入れる必要があります。これが最も難しい境界です |
| Prefix caching | DMS の eviction はシーケンス単位かつヘッド単位です。共有プレフィックスのブロックには、シーケンスごとの eviction オーバーレイが必要か、無効化する必要があります |
| Continuous batching | メモリ計上は、論理的なシーケンス長ではなく、実際に生き残ったトークン数を反映する必要があります |
これを一発でやってみようという人には幸あれ。kvcache の圧縮は確かに実在するように見え、適切な実装ができれば品質の低下はありません。そして FastDMS の実装が示しているように、DMS なしの推論よりも 高速に動かせる ようです。
(関心のある方のために、リポジトリ内にさらに多くの perf ベンチマーク、比較、そして生のログがあります)
[link] [comments]
