I've been playing with KV cache INT4 quantization and noticed something weird: it works perfectly on some models and completely destroys others.
Examples:
- Falcon-40B: ΔPPL +0.08 ✅ (basically free compression)
- OPT-13B: ΔPPL +0.28 ✅
- Qwen2-7B: ΔPPL +238 ❌ (output becomes incoherent garbage)
- Pythia-6.9B: ΔPPL +22 ❌
- Pythia-410M: ΔPPL +77 ❌
Same quantization method. Why does it break on some and not others?
Root cause: two independent problems
- Token-wise norm variation — KV vector norms fluctuate 2-5x across tokens in Pre-LN models. Per-row absmax gives inconsistent quantization precision across the sequence.
- Activation outlier channels — Certain channels have values 8-100x larger than average (Qwen2-7B Layer 0: 8 channels at absmax 167). They hijack the quantization scale and kill precision on all other channels.
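To see the second problem concretely, here's a toy sketch (synthetic data, not from any real model) of how a single outlier channel hijacks an absmax INT4 scale and wipes out precision for every other channel in the row:

```python
import torch

torch.manual_seed(0)
# 64 "normal" channels plus one outlier channel ~100x larger,
# mimicking the outlier-channel pattern described above.
x = torch.cat([torch.randn(64), torch.tensor([100.0])])

def int4_absmax(v):
    # Symmetric INT4: one absmax scale, integer levels in [-7, 7]
    scale = v.abs().amax() / 7
    return (v / scale).round().clamp(-7, 7) * scale

# Mean abs error on the 64 normal channels, with and without the outlier
err_with = (int4_absmax(x) - x)[:64].abs().mean()
err_without = (int4_absmax(x[:64]) - x[:64]).abs().mean()
print(err_with, err_without)
```

With the outlier present, the grid step is ~14.3, so every normal channel rounds to zero and the quantizer effectively erases them; without it, the step is ~0.4 and the error drops by roughly an order of magnitude.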
Fixing only one doesn't help much:
- Norm separation only → +57.5 (still bad)
- Per-channel only → +97.8 (still bad)
- Both combined → +0.32 (744x improvement)
The fix (nsep+pchan)
Before quantizing, decompose each KV vector into norm (FP16 scalar) and direction (unit vector, quantized). Then use per-channel scaling. 4 lines of PyTorch:
```python
norm = x.norm(dim=-1, keepdim=True)
direction = x / norm
scale = direction.abs().amax(dim=-1, keepdim=True) / 7
direction_q = (direction / scale).round().clamp(-7, 7) * scale
```

No training. No calibration. No model-specific tuning. It's a drop-in preprocessing step — it works in front of whatever quantization you're already using.
Results across 12 models (124M to 40B)
| Model | naive INT4 | nsep+pchan | Improvement |
|---|---|---|---|
| Qwen2-7B | +238 | +0.32 | 744× |
| Pythia-6.9B | +22 | +0.27 | 82× |
| Pythia-12B | +27 | +1.82 | 15× |
| Pythia-410M | +78 | +12.62 | 6× |
| Falcon-40B | +0.08 | +0.04 | 2× |
| OPT-13B | +0.28 | +0.35 | 1× |
Full table with all 12 models in the paper.
Key point: it never hurts. The worst-case degradation is +0.24 ΔPPL (OPT-125M). Models that already work fine under naive INT4 see no meaningful change.
Long context gets even crazier
At 4096 tokens on Qwen2-7B:
- naive INT4: ΔPPL +8293
- nsep+pchan: ΔPPL +0.19
- That's a 44,000× improvement
The error accumulates in attention computation as the KV cache grows. Norm separation prevents this compounding.
Bonus: INT3 > INT4 on Qwen2-7B
This one surprised me. On Qwen2-7B, INT3 (ΔPPL +6.6) is actually 36x better than INT4 (ΔPPL +238). The reason: INT4 maps mid-range values to noisy non-zeros because of the outlier-dominated scale. INT3's coarser grid maps them to clean zeros instead. In attention, clean zeros beat noisy non-zeros. Per-channel quantization fixes this for both bit widths.
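A toy illustration of the zero-mapping effect, using the absmax of 167 reported above and a hypothetical mid-range value:

```python
import torch

absmax = 167.0            # outlier absmax reported for Qwen2-7B Layer 0
v = torch.tensor([13.0])  # hypothetical mid-range value, far below the outlier

def sym_quant(v, levels):
    # Symmetric quantization with the outlier-dominated scale
    scale = absmax / levels
    return (v / scale).round().clamp(-levels, levels) * scale

print(sym_quant(v, 7))  # INT4 (levels ±7): snaps to ~23.9 — a noisy non-zero
print(sym_quant(v, 3))  # INT3 (levels ±3): snaps to 0.0 — a clean zero
```

The INT4 grid step (167/7 ≈ 23.9) is wide enough that the value rounds *up* to almost 2× its true magnitude; the INT3 step (167/3 ≈ 55.7) rounds it to zero, which perturbs the attention scores far less.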
PyTorch .norm() trap
While building this, I got bitten by a PyTorch API gotcha: `x.norm(-1, keepdim=True)` computes the L_{-1} norm, NOT the L2 norm — the first positional argument of `.norm()` is `p` (the norm order), not `dim`. Correct: `x.norm(dim=-1, keepdim=True)`. The bug was invisible on CPU but blew values up ~3000× on CUDA. Embarrassing, but posting it in case it saves someone else the debugging.
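A minimal repro of the gotcha on a 3-4-5 vector, so you can see the two calls really do compute different things:

```python
import torch

x = torch.tensor([[3.0, 4.0]])

l2 = x.norm(dim=-1)     # L2 norm over the last dim: tensor([5.])
p_neg1 = x.norm(-1)     # first positional arg is p! p=-1 norm over ALL elements:
                        # (1/3 + 1/4)**-1 ≈ 1.714
print(l2, p_neg1)
```

The newer `torch.linalg.vector_norm(x, ord=2, dim=-1)` makes the norm order explicit and avoids this trap entirely.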
Paper: https://doi.org/10.5281/zenodo.19590278
Code + all results: https://github.com/metaSATOKEN/norm-separated-quantization
Happy to answer questions. All experiments are reproducible.