大規模言語モデルにおけるアテンションヘッドの線形予測可能性

arXiv cs.LG / 2026/3/17

📰 ニュースIdeas & Deep AnalysisModels & Research

要点

  • 本論文は、事前学習済みのトランスフォーマーにおける広く存在するヘッド間の線形構造を特定し、注意機構のヘッドのQKVベクトルが同一層内の限られた数のピアヘッドの線形結合として再構成可能であることを示しています。
  • Llama-3.1-8B、Falcon3-10B、OLMo-2-7B、Qwen3-32B など複数のモデルを横断して、2〜5個の参照ヘッドが多くのターゲットヘッドを高忠実度で復元でき、C4のキーで平均R^2が約0.76、GSM8Kではしばしば0.85を超える。
  • この予測可能性はアーキテクチャによるものではなく事前学習中に学習されるようで、ランダム初期化時にはほとんど現れず、チェックポイントを経るにつれて現れ、初期化時に高い誤差を生じることを裏付ける理論的境界が存在する。
  • 本研究はこの出現を、層内におけるキー投影サブスペースの整合性が高まることと関連付けている。
  • 実践的には、著者らは参照ヘッドのKV状態のみをキャッシュし、残りをその場で再構築することを提案しており、約2倍のKVキャッシュ削減を、許容可能な小さな精度トレードオフとともに実現している。また、キーの再構築は値の再構築よりも悪影響が少ないことを示している。

要約: 大規模言語モデル(LLM)の推論はますます Key-Value(KV)キャッシュによってボトルネックとなっており、注意ヘッドの活性化の細かな構造は依然として十分には理解されていない。事前学習済みの Transformer はヘッド間に広く見られる線形構造を示す。特定のトークンに対して、注意機構のヘッドの Query、Key、Value(QKV)ベクトルは、しばしば同じ層内の限られた数の同僚ヘッドの線形結合として再構成できる。Llama-3.1-8B、Falcon3-10B、OLMo-2-7B、Qwen3-32B にまたがって、わずか 2–5 匹の参照ヘッドが多くのターゲットヘッドを高忠実度で回復する(例:C4 の Keys に対して参照が5つの場合の平均 R^2 は約 0.76、GSM8K では頻繁に R^2 > 0.85)。この予測可能性はアーキテクチャ上の性質ではなく学習によって獲得されるもので、ランダム初期化時にはほとんど欠如しており、OLMo-2 のチェックポイントを追跡する過程で事前学習中に急速に高まる。初期化時の線形予測の平均二乗誤差が大きいことを示す理論的下限によっても支持される。さらに、この出現を Key 投影サブスペースの層内整列の増加と関連付ける。最後に、この冗長性を効率化のために活用し、参照ヘッド KV 状態のみをキャッシュして残りのヘッドを軽量な線形マップでその場で再構成することで、モデル依存の精度トレードオフを伴いながら KV キャッシュを 2 倍削減している(Falcon3-10B および Qwen3-32B で 5 つのベンチマーク全体の平均で 4.5–5.5 ポイントの精度低下、Llama-3.1-8B ではより大きな低下)。また、Keys の再構成は Values の再構成より大幅に害が少ないことがわかった。