Gated Delta Net(GDN)からQ/K投影を削除しても性能を維持しつつ、パラメータを約15%少なくできる

Reddit r/LocalLLaMA / 2026/4/4

💬 オピニオンIdeas & Deep AnalysisModels & Research

要点

  • Redditの投稿によると、Gated Delta Net(GDN)では、別個のQ(query)およびK(key)投影層を削除しても性能を維持できる可能性があり、層あたりのパラメータ数を約12.5%〜約25%削減できる。
  • 提案手法は、現在の隠れ状態をクエリベクトルとして、前の隠れ状態をキーべクトルとして用いることで、注意(attention)の振る舞いを保ちつつアーキテクチャを簡素化することを狙っている。
  • コーディングデータ(Stack)で、約1億(100M)パラメータのモデルを3億(300M)トークン学習した実験では、シフト版キーの方が標準のGDNよりもわずかに適合した学習損失になった(1.02 vs 1.03)。
  • 著者は、「シフトされたキー/QK投影なし」というアイデアはsoftmax attentionにはそのままは適用できないと述べており、この設計には注意機構の違いが重要であることを示唆している。
  • 本研究は既存リポジトリを参照し、このコンセプトの発見はOpus 4.6に帰しており、完全に新しい学習レシピというよりアーキテクチャ上の微調整として位置づけている。
Gated Delta Net から Q/K の射影を取り除いても、約15%少ないパラメータで性能を維持

みなさんこんにちは。Gated Delta Net(GDN)のアーキテクチャで作業していて、Q/K の射影を完全に取り除いても、実際のところほとんど問題ないことを見つけました。
シフトされたキーに対して、線形アテンションと softmax アテンションがこれほどまでに挙動を変えるのはなぜなのか、うまい説明を知っている人はいますか?

リポジトリ: https://github.com/jfguan/shifted_gdn/blob/main/README.md

驚くことに、Gated Delta Net では、クエリとキーの射影を直接使うことで取り除けます:

  1. 現在の隠れ状態をクエリベクトルとして使う
  2. 前の隠れ状態をキーベクトルとして使う

TLDR: 収束が速くなり、厳密にパラメータが減っているにもかかわらず性能はわずかに良く、さらにある層のパラメータの約12.5%〜25%を節約できます。

(Stack の)コーディングサンプル 300M トークンで訓練した、約100Mパラメータのモデルでは、Shifted Key Gated Delta Net の適合訓練損失が 1.02(通常の Gated Delta Net が 1.03)でした。

また同じコンセプトは softmax アテンションには適用できないことも示します。コンセプトは Opus 4.6 で発見されました。

シフトは RWKV の token lerp に似ていますが、Q/K の射影を完全に取り除きます。

アテンションの簡単な復習

アテンションは、位置 t の隠れ状態 x_t を使って、各過去トークンに対応するキー k_t と値 v_t ベクトルを生成し、さらに現在のクエリベクトル q_t も作ります。

単純化した例として、単語トークンで「空欄」を予測する必要があるとします。

https://preview.redd.it/jdrakf3pb3tg1.png?width=1388&format=png&auto=webp&s=ecd847d83445aa90c926f599e54bde590554f32f

キー ベクトルはトークン「何を」を符号化し、値ベクトルは「文脈の中での意味」を符号化し、クエリベクトルは「他のどのトークンが関連しているか」という現在の予測に関するものを符号化します。

この例では、クエリベクトル q_7 を使うと、q_7 · k_t が任意の過去トークン t の関連度を教えてくれます。たとえば `dog` と `barked` のほうが `The` よりも関連が高い、という具合です。

関連度スコアは softmax によって正規化され、その結果として、最終予測に役立つすべての過去の値ベクトルの重み付き平均が得られます。

線形アテンションの簡単な復習

アテンションでは、すべての過去の k, v ベクトルを保持する必要があるため、計算コストはシーケンス長に比例して増大します。線形アテンションは、代わりに固定サイズの状態で回避します。

利点: メモリ/計算コストが増えない。

欠点: ただでは済みません。圧縮は本質的に情報を失い、想起(リコール)が悪化します。

仕組みの説明:

2つの k, v ベクトルがあるとき、まず外積 v⊗k を取り、これは (v · k^T) とも書けます。

その後、v⊗k にもう一度 k を掛けると、v · (k^T @ k) = v · ‖k‖² になります。

注意: v⊗k は行列です。行列に k を掛けると、k にスケールされた v が返ってきます。

各トークンの k, v を、固定サイズの行列 M に対して M += v⊗k と加算していくことで保存します。新しい k, v の組が来るたびにメモリへ追加していきます。

しかし M のサイズは固定なので、いずれすべてのキーが重なり始めます。そのため、もし2つのキーが似ていた場合、問い合わせは対応する2つの値の組み合わせを返してしまいます。つまり M は、情報を失う固定サイズの KV キャッシュだと考えられます。

実際には、さまざまなゲーティングや減衰の仕組みによって、キーの衝突/容量の問題は軽減されます。

Shifted Key(シフトされたキー)トリック

通常、q, k ベクトルは学習された q, k 射影から生成されますが、shifted key トリックではその学習された射影を完全にスキップします。その代わりに、次を直接使います:

(x_t は位置 t の隠れ状態):

  1. v_t に対してキーベクトル k_t として x_{t-1} を使う。これにより、過去の状態が現在の値に結び付けられます。
  2. クエリベクトルとして x_t を使う。キーがシフトされているため、x_t でメモリ行列を問い合わせると「x_t に似た位置では、何が後に来たか?」が返ってきます。

先ほどの例に戻ると:

https://preview.redd.it/ysjrxyirb3tg1.png?width=1304&format=png&auto=webp&s=0118ac187d0db5ecff25e2574e208cdd3e784ddc

関連付けが次のようになります:

  1. The -> dog
  2. dog -> barked
  3. barked. -> The
  4. The -> man
  5. man -> saw

...

空欄を予測するために、我々の隠れ状態 x_7 は「dog」であり、x_1 と似ています。これにより「barked」のための v_2 表現が強まります。

shifted key のハードな事前知識は、通常は学習された Q/K 射影で解決される線形アテンションの対称なメモリ行列の問題を修正します。隠れ状態 x_t が k_t, v_t ベクトルの両方への入力になるため、対称なキー—値の組は次に来るものを符号化しません。たとえば、キーは「私は dog トークンである」を表し、値は「dog の意味」を表すかもしれません。shifted key がないと、現在の隠れ状態は「dog」なので、行列を問い合わせると「dog の意味」が返ってきます。これは本来欲しかった「bark の意味」とは異なります。

この対称性の問題は softmax アテンションには当てはまりません。softmax アテンションは、問い合わせのためにすべての過去キーを保持します。

また shifted key は「コピー/ペースト」とも考えられます。つまり x を見たら y を考える、という感じです。関連付けが隣接するトークンに制限されるので、かなり制約が強いように見えるのは確かです。

しかし、経験的には 100M パラメータ規模でもそれでも機能しているようです。これは、線形アテンションモデルにおいて q, k 射影が主に次のことを行っていることを示唆しているのかもしれません:

  1. メモリ行列における対称性を破ることを学習する
  2. キー空間を完全に活用するための良い直交キーを形成する
  3. 生の単語ではなく抽象的概念を関連付ける

生の隠れ状態が、これらの責務を十分に、あるいはそれ以上にうまく果たしているようです。

実験

免責事項 - すべてのモデルは、十分にではないにせよかなり訓練不足の状態です。カーブは、序盤の訓練の影響が大きくなりすぎないように、訓練の最後の80%に対してフィットさせています。シーケンス長は 2048、語彙数は 1024 です。

18M スケールのテスト

ベースラインとして 17.9M パラメータの Gated Delta Net と、14.7M パラメータの Shifted Key Gated Delta Net を、コード例(The Stack)での 30M トークン、バッチサイズ 4 の条件で学習させます。QK の削除以外は、層数とモデル次元は同じです。

平滑化されたデータポイントでの訓練損失を見ると、パラメータ数が少なく表現力も少ないにもかかわらず、トークンシフトのほうが良い結果が得られます。

https://preview.redd.it/amyjuncub3tg1.png?width=2024&format=png&auto=webp&s=01986c04440767d1b4efe55896610dad698d5cd7

しかしトランスフォーマーでは、shifted key トランスフォーマーのほうが悪い結果になります。これは、softmax アテンションと線形アテンションが似たコンセプトから派生しているとはいえ、実際には挙動が異なることを示唆しています。どちらもパターンマッチングを行っている一方で、softmax アテンションは正確な過去キーへの問い合わせ/想起によって行い、線形アテンションはより曖昧な一般的パターンマッチングを行っているのかもしれません。

https://preview.redd.it/0r7hsj3wb3tg1.png?width=2018&format=png&auto=webp&s=573b71a44d13c7bae84488d4dabd03bc02545638

100Mスケールテスト

ゲート付きデルタネットでは105Mまで、シフト付きキー付きゲートデルタネットでは86.2Mまでスケールし、300Mトークンで学習します。バッチサイズは1です。

https://preview.redd.it/d3ra17exb3tg1.png?width=2020&format=png&auto=webp&s=19b571c2dad95fc23e9839b0c744090a6149a300

シフト付きキーのモデルは、パラメータ数が約15%少ないにもかかわらず、わずかなリードを維持しています。また、QKの射影を学習する必要がないため、収束もより速くなります。

最後に、シフト付きキーのモデルは、3つの指標によって、各層にわたって情報を保存するために自分の「キー」を「よりうまく」使っているように見えます。

  1. 有効ランク - どれだけ多くの異なるキーが保存されているか。
  2. 平均ペアワイズコサイン - 純粋な取得(クリーンなリトリーバル)において、キーがどれだけ近く、かつ「ごちゃ混ぜ」になっているか。
  3. 条件数 - キー全体が次元の「保存」空間をどれだけうまく使っているか。

https://preview.redd.it/ns9ddrkyb3tg1.png?width=2028&format=png&auto=webp&s=26b6afce0d1bc6255b3444a35dc856f6f7790e9c

シフト付きキーのモデルは、層0における条件数を除くすべての指標でより良い結果を示します。これは、パディング用のキーを追加したことによるアーティファクトであり、位置0には、キーとして使うべき先行する隠れ状態が存在しないためです。

結論

これがなぜうまく機能するのか、正確にはよく分かりません。関連付けを鎖のようにつなげて記憶を形成できるという直感的な説明は成り立つように見える一方で、隣接するトークン同士を直接関連付けることに制限しても性能にあまり影響が出ないのは混乱します。おそらくこれはスケールすると制約が厳しすぎるのかもしれませんが、それでも線形注意に関連するモデルが、何らかの点で本当に別物であることは示しているように見えます。

submitted by /u/jfguan
[リンク] [コメント]