最近、自分にこう問いかけました。自己注意 (self-attention) の標準的なドット積を、別の距離指標(たとえば rbf カーネル)で置き換えたらどうなるのだろう、と。
標準のドット積注意には、ちょっとしたクセがあります。キーのベクトルが、巨大な大きさ(ノルム)を持つだけでソフトマックスを「いじめて」しまえるのです。だいたい正しい方向を向いているだけのランダムなキーであっても、サイズが大きければ、完全に整列しているのに短いキーを簡単に上回ります。距離ベース(RBF)注意ならこれを直せます。高い注意スコアを得るには、Q と K 実際に 高次元空間で互いに近くなっている必要があります。大きくすればごまかせるわけではありません。
これは 10 分で終わる PyTorch の簡単な実験になると思っていたのですが、ドット積が ML スタック全体にどれほど深くハードコードされているかを思い知らされました。1 つのコア演算を変えただけで、巨大なドミノ効果が起きました。:D
以下に、壊れてしまったものの連鎖と、学習できるくらいまともに動くモデルを得るために私がどう直したかを示します:
即 OOM: torch.cdist を使って(matmul のトリックなしで)素朴にペアごとのユークリッド距離を計算すると、完全な N x N の距離行列がメモリに展開されてしまいます。まともなコンテキスト長の時点で即座に OOM になります。幸い、ちょっと高校レベルの代数で、距離の二乗の公式を展開して -||Q||2 - ||K||2 + 2(Q · K) を得られます。ソフトマックスはシフト不変なので、クエリのノルムはその特定のクエリに対する定数に過ぎず、捨ててしまえます。残るのは 2(Q · K) - ||K||2 です。ここで、RBF 注意は数学的には「キーに対する二乗 L2 のペナルティが組み込まれた標準のドット積注意」そのものだということが分かります。
カスタムカーネル: その数式トリックがあっても、PyTorch のネイティブなスケールド・ドット積注意 (SDPA) は、融合ループの中でキーのノルムペナルティを任意に引くことを許してくれません。ダミー次元でテンソルをパディングしてハックすることはできますが、扱いにくいうえに不要なメモリ移動が発生するので、諦めてカスタムの Triton カーネルを書きました。これは FlashAttention のタイル分割ロジックをなぞっていますが、SRAM 上でキーの二乗 L2 ノルムをその場で計算し、ソフトマックスの直前でそれを差し引きます。すると、必要なのは線形メモリだけになります。
Attention Sinks(注意の受け皿): どうやら、モデルが「大きさで押しのける (magnitude bullying)」ことによって Attention Sink(注意の受け皿)を作る必要がある場合があるようです。彼らは役に立たないトークン(たとえば <BOS>)をスケールアップして、クエリが文脈を気にしないときに、注意マスを捨てる場所を持てるようにします。しかし距離の数学では、巨大なベクトルは無限の距離を意味し、その結果確率がゼロになり、ユークリッド空間で普遍的な受け皿になるには、キーはちょうど原点に位置していなければなりません。そこで、レジスタトークンでそれを解決しました。シーケンスの先頭に学習可能なダミーベクトルを付け、その初期化をゼロにしました。クエリが何か有用なものを見つけられない場合、それは自然にレジスタトークンへフォールバックし、実際のトークンを壊さずに注意を空のレジスタへ安全に捨てられます。
RoPE はもう意味が通らない: 現代のモデルは RoPE を使います。これはベクトルを明示的に回転させます。ドット積に対しては(相対角度として)数学的にエレガントなのですが、絶対的な空間ユークリッド距離を測る前にベクトルへ回転を適用すると、幾何学が完全に破壊されてしまい、意味をなさなくなります…。そこで私は RoPE を完全に取り払い、SuSiE(Subspace Sinusoidal Embeddings)に置き換えました。これは、キャッシュされた「回転されていない」正弦波をそのままベクトルに足し込むだけです。足し算である以上、位置による距離はユークリッド空間におけるペナルティとして明示的に働きます。
実際にうまくいったの? うーん… ある程度は。ごく小さな TinyStories のデータセットで、小さな因果モデルを学習しました。標準的な SDPA ベースラインよりもわずかに収束が速かったです。おそらく距離の数式と、ソフトマックス前のロジットが 0 にキャップされていたことで、初期の勾配スパイクが抑えられたのかもしれませんが、真相は不明です…?
大規模モデルで、近いうちに FlashAttention を置き換えることになるのか? いいえ。GPU と ML スタック全体は、純粋なドット積に対して非常に最適化されており、業界は QK-Norm で「大きさによる押しのけ」を解決してしまいました。とはいえ、ML スタックの一部を壊して組み直すという楽しいエンジニアリング演習にはなりました。
私は全部ひととおりやったので、あなたがやる必要はありません。コードはこちら:
Blog-Post: https://pisoni.ai/posts/scaled-rbf-attention/
Repo: https://github.com/4rtemi5/rbf_attention
[link] [comments]




