| みなさんこんにちは。Gated Delta Net(GDN)のアーキテクチャで作業していて、Q/K の射影を完全に取り除いても、実際のところほとんど問題ないことを見つけました。 リポジトリ: https://github.com/jfguan/shifted_gdn/blob/main/README.md 驚くことに、Gated Delta Net では、クエリとキーの射影を直接使うことで取り除けます:
TLDR: 収束が速くなり、厳密にパラメータが減っているにもかかわらず性能はわずかに良く、さらにある層のパラメータの約12.5%〜25%を節約できます。 (Stack の)コーディングサンプル 300M トークンで訓練した、約100Mパラメータのモデルでは、Shifted Key Gated Delta Net の適合訓練損失が 1.02(通常の Gated Delta Net が 1.03)でした。 また同じコンセプトは softmax アテンションには適用できないことも示します。コンセプトは Opus 4.6 で発見されました。 シフトは RWKV の token lerp に似ていますが、Q/K の射影を完全に取り除きます。 アテンションの簡単な復習アテンションは、位置 t の隠れ状態 x_t を使って、各過去トークンに対応するキー k_t と値 v_t ベクトルを生成し、さらに現在のクエリベクトル q_t も作ります。 単純化した例として、単語トークンで「空欄」を予測する必要があるとします。 キー ベクトルはトークン「何を」を符号化し、値ベクトルは「文脈の中での意味」を符号化し、クエリベクトルは「他のどのトークンが関連しているか」という現在の予測に関するものを符号化します。 この例では、クエリベクトル q_7 を使うと、q_7 · k_t が任意の過去トークン t の関連度を教えてくれます。たとえば `dog` と `barked` のほうが `The` よりも関連が高い、という具合です。 関連度スコアは softmax によって正規化され、その結果として、最終予測に役立つすべての過去の値ベクトルの重み付き平均が得られます。 線形アテンションの簡単な復習アテンションでは、すべての過去の k, v ベクトルを保持する必要があるため、計算コストはシーケンス長に比例して増大します。線形アテンションは、代わりに固定サイズの状態で回避します。 利点: メモリ/計算コストが増えない。 欠点: ただでは済みません。圧縮は本質的に情報を失い、想起(リコール)が悪化します。 仕組みの説明: 2つの k, v ベクトルがあるとき、まず外積 v⊗k を取り、これは (v · k^T) とも書けます。 その後、v⊗k にもう一度 k を掛けると、v · (k^T @ k) = v · ‖k‖² になります。 注意: v⊗k は行列です。行列に k を掛けると、k にスケールされた v が返ってきます。 各トークンの k, v を、固定サイズの行列 M に対して M += v⊗k と加算していくことで保存します。新しい k, v の組が来るたびにメモリへ追加していきます。 しかし M のサイズは固定なので、いずれすべてのキーが重なり始めます。そのため、もし2つのキーが似ていた場合、問い合わせは対応する2つの値の組み合わせを返してしまいます。つまり M は、情報を失う固定サイズの KV キャッシュだと考えられます。 実際には、さまざまなゲーティングや減衰の仕組みによって、キーの衝突/容量の問題は軽減されます。 Shifted Key(シフトされたキー)トリック通常、q, k ベクトルは学習された q, k 射影から生成されますが、shifted key トリックではその学習された射影を完全にスキップします。その代わりに、次を直接使います: (x_t は位置 t の隠れ状態):
先ほどの例に戻ると: 関連付けが次のようになります:
... 空欄を予測するために、我々の隠れ状態 x_7 は「dog」であり、x_1 と似ています。これにより「barked」のための v_2 表現が強まります。 shifted key のハードな事前知識は、通常は学習された Q/K 射影で解決される線形アテンションの対称なメモリ行列の問題を修正します。隠れ状態 x_t が k_t, v_t ベクトルの両方への入力になるため、対称なキー—値の組は次に来るものを符号化しません。たとえば、キーは「私は dog トークンである」を表し、値は「dog の意味」を表すかもしれません。shifted key がないと、現在の隠れ状態は「dog」なので、行列を問い合わせると「dog の意味」が返ってきます。これは本来欲しかった「bark の意味」とは異なります。 この対称性の問題は softmax アテンションには当てはまりません。softmax アテンションは、問い合わせのためにすべての過去キーを保持します。 また shifted key は「コピー/ペースト」とも考えられます。つまり x を見たら y を考える、という感じです。関連付けが隣接するトークンに制限されるので、かなり制約が強いように見えるのは確かです。 しかし、経験的には 100M パラメータ規模でもそれでも機能しているようです。これは、線形アテンションモデルにおいて q, k 射影が主に次のことを行っていることを示唆しているのかもしれません:
生の隠れ状態が、これらの責務を十分に、あるいはそれ以上にうまく果たしているようです。 実験免責事項 - すべてのモデルは、十分にではないにせよかなり訓練不足の状態です。カーブは、序盤の訓練の影響が大きくなりすぎないように、訓練の最後の80%に対してフィットさせています。シーケンス長は 2048、語彙数は 1024 です。 18M スケールのテスト ベースラインとして 17.9M パラメータの Gated Delta Net と、14.7M パラメータの Shifted Key Gated Delta Net を、コード例(The Stack)での 30M トークン、バッチサイズ 4 の条件で学習させます。QK の削除以外は、層数とモデル次元は同じです。 平滑化されたデータポイントでの訓練損失を見ると、パラメータ数が少なく表現力も少ないにもかかわらず、トークンシフトのほうが良い結果が得られます。 しかしトランスフォーマーでは、shifted key トランスフォーマーのほうが悪い結果になります。これは、softmax アテンションと線形アテンションが似たコンセプトから派生しているとはいえ、実際には挙動が異なることを示唆しています。どちらもパターンマッチングを行っている一方で、softmax アテンションは正確な過去キーへの問い合わせ/想起によって行い、線形アテンションはより曖昧な一般的パターンマッチングを行っているのかもしれません。 |
100Mスケールテスト
ゲート付きデルタネットでは105Mまで、シフト付きキー付きゲートデルタネットでは86.2Mまでスケールし、300Mトークンで学習します。バッチサイズは1です。
シフト付きキーのモデルは、パラメータ数が約15%少ないにもかかわらず、わずかなリードを維持しています。また、QKの射影を学習する必要がないため、収束もより速くなります。
最後に、シフト付きキーのモデルは、3つの指標によって、各層にわたって情報を保存するために自分の「キー」を「よりうまく」使っているように見えます。
- 有効ランク - どれだけ多くの異なるキーが保存されているか。
- 平均ペアワイズコサイン - 純粋な取得(クリーンなリトリーバル)において、キーがどれだけ近く、かつ「ごちゃ混ぜ」になっているか。
- 条件数 - キー全体が次元の「保存」空間をどれだけうまく使っているか。
シフト付きキーのモデルは、層0における条件数を除くすべての指標でより良い結果を示します。これは、パディング用のキーを追加したことによるアーティファクトであり、位置0には、キーとして使うべき先行する隠れ状態が存在しないためです。
結論
これがなぜうまく機能するのか、正確にはよく分かりません。関連付けを鎖のようにつなげて記憶を形成できるという直感的な説明は成り立つように見える一方で、隣接するトークン同士を直接関連付けることに制限しても性能にあまり影響が出ないのは混乱します。おそらくこれはスケールすると制約が厳しすぎるのかもしれませんが、それでも線形注意に関連するモデルが、何らかの点で本当に別物であることは示しているように見えます。
submitted by /u/jfguan[リンク] [コメント]




