KVキャッシュINT4が一部のモデルで壊れる理由を見つけた(Qwen2-7B: ΔPPL +238)。学習も校正も不要の4行修正を作成:最大40Bまで12モデルで検証

Reddit r/LocalLLaMA / 2026/4/16

💬 オピニオンIdeas & Deep AnalysisTools & Practical UsageModels & Research

要点

  • この記事では、KVキャッシュのINT4量子化は一部のトランスフォーマーモデルでは成功する一方、他のモデルでは壊滅的に失敗し得ることが説明されており、Qwen2-7B(ΔPPL +238)とFalcon-40B(ΔPPL +0.08)の例が示されている。
  • 2つの相互に作用する根本原因が特定されている:Pre-LNモデルにおけるトークンごとのKVノルムの変動と、量子化スケールを全チャンネルの精度を崩す方向に強制してしまう活性アウトライヤーチャンネルである。
  • 提案するドロップイン修正「nsep+pchan」では、各KVベクトルをFP16のノルムスカラーと単位方向ベクトルに分解し、その後チャンネルごとのスケーリングを適用して、方向部分のみをINT4範囲にクランプしつつ量子化する。
  • 最大40Bまでの12モデルにわたる結果では、従来壊れていたケースが大幅に回復しており、Qwen2-7Bで744×、Pythia-6.9Bで82×の改善が報告されている。さらに、この修正は性能を実質的に悪化させることがないとされている。

KVキャッシュのINT4量子化をいじっていて、変なことに気づきました。うまくいくモデルがある一方で、他のモデルでは完全に壊れてしまうのです。

例:

  • Falcon-40B: ΔPPL +0.08 ✅(基本的にタダの圧縮)
  • OPT-13B: ΔPPL +0.28 ✅
  • Qwen2-7B: ΔPPL +238 ❌(出力が支離滅裂なゴミになる)
  • Pythia-6.9B: ΔPPL +22 ❌
  • Pythia-410M: ΔPPL +77 ❌

同じ量子化手法なのに、なぜ一部では壊れて一部では壊れないのでしょうか?

根本原因:2つの独立した問題

  1. トークンごとのノルム変動 — Pre-LNモデルでは、KVベクトルのノルムがトークン間で2〜5倍も揺らぎます。行ごとのabsmaxを使うと、配列全体で量子化の精度が不一致になります。
  2. アクティベーションの外れ値チャネル — 特定のチャネルは平均より8〜100倍大きい値を持ちます(Qwen2-7B Layer 0:absmax 167のチャネルが8つ)。それらが量子化スケールを乗っ取り、他の全チャネルの精度を殺してしまいます。

片方だけ直してもあまり効果がありません:

  • ノルム分離のみ → +57.5(まだダメ)
  • チャネルごとのみ → +97.8(まだダメ)
  • 両方まとめて → +0.32(744倍の改善)

修正(nsep+pchan)

量子化する前に、各KVベクトルをノルム(FP16のスカラー)と方向(単位ベクトル、量子化)に分解します。次にチャネルごとのスケーリングを使います。PyTorchは4行:

norm = x.norm(dim=-1, keepdim=True) direction = x / norm scale = direction.abs().amax(dim=-1, keepdim=True) / 7 direction_q = (direction / scale).round().clamp(-7, 7) * scale 

学習不要。キャリブレーション不要。モデル固有の調整不要。差し替え可能な前処理ステップで、すでに使っているどんな量子化にもそのまま前に置けます。

12モデルにわたる結果(124M〜40B)

Model naive INT4 nsep+pchan Improvement
Qwen2-7B +238 +0.32 744×
Pythia-6.9B +22 +0.27 82×
Pythia-12B +27 +1.82 15×
Pythia-410M +78 +12.62
Falcon-40B +0.08 +0.04
OPT-13B +0.28 +0.35

論文には、12モデルすべてを含む完全な表があります。

キー:それは決して損しません。最悪の劣化は +0.24 ΔPPL(OPT-125m)です。naive INT4ですでに問題なく動いているモデルでは、有意な変化は見られません。

長いコンテキストほどさらに狂う

Qwen2-7Bで4096トークンのとき:

  • naive INT4: ΔPPL +8293
  • nsep+pchan: ΔPPL +0.19
  • つまり44,000×の改善

KVキャッシュが大きくなるにつれて、誤差は注意計算で蓄積していきます。ノルム分離は、この“増殖”を防ぎます。

おまけ:Qwen2-7BではINT3 > INT4

これは正直驚きました。Qwen2-7Bでは、INT3(ΔPPL +6.6)のほうがINT4(ΔPPL +238)より実際に36倍優れています。理由は、INT4が外れ値に支配されたスケールのせいで、中間レンジの値をノイジーな非ゼロに割り当ててしまうからです。INT3はより粗いグリッドなので、それらをきれいなゼロに割り当てます。注意では、ノイジーな非ゼロよりきれいなゼロのほうが有利です。チャネルごとの量子化により、両方のビット幅についてこの問題が直ります。

PyTorch .norm() の罠

これを作っている最中に、PyTorch APIの落とし穴にやられました:x.norm(-1, keepdim=True) はL_{-1}ノルムを計算しており、L2ノルムではありません。最初の引数はdimではなくp(ノルムの次数)です。正しくは x.norm(dim=-1, keepdim=True)。CPUでは見えず、CUDAでは3000倍に爆発しました。恥ずかしいですが、誰かのデバッグの助けになればと思い投稿します。

論文:https://doi.org/10.5281/zenodo.19590278

コード+全結果:https://github.com/metaSATOKEN/norm-separated-quantization

質問には喜んで答えます。すべての実験は再現可能です。

submitted by /u/Afraid_Project_8666
[link] [comments]