注意だけで十分?でも払えないのはその分だけ|Hybrid Attention

Reddit r/artificial / 2026/4/7

💬 オピニオンDeveloper Stack & InfrastructureIdeas & Deep AnalysisModels & Research

要点

  • 開発者が、PyTorchでスクラッチから学習したRust向けのバイトレベルGPT系言語モデルを公開した。25.6Mパラメータ、512のコンテキスト長、RTX 4060 Ti 8GBを1枚用いて30kステップ学習した。
  • モデルの主要なアーキテクチャ上の変更点は「HybridAttention」で、局所的なウィンドウ付きの因果(causal)注意に、GRUのようなリカレント状態パスと、両者を混ぜ合わせる学習済みゲートを組み合わせることで、長距離モデリングのコストを削減する。
  • 訓練実行の報告では、HybridAttention + KVキャッシュにより推論効率が向上し、286.6 tokens/秒を達成したとされており、これはフルアテンションに比べて51.47倍高速だと主張されている。
  • 記事では、コーパス構築が最も重要だったとしており、Rust中心のデータを約31MBから、上位500のRustクレートを複製して用いることで約173.5MBまで拡大した。
  • 著者は「システムズ(systems)」アプローチを取り、多くの実験的なメモリ最適化を無効化しつつ、単純なアーキテクチャとより良いデータが、強い損失/パープレキシティの結果に十分だと論じている。

リポジトリ: https://codeberg.org/JohannaJuntos/Sisyphus

私はPyTorchからスクラッチで、小さなRust特化の言語モデルを作っています。ファインチューニングではなく、バイトレベルで、Rust比率の高いこのリポジトリのコーパスを使い、ランダム初期化から学習しました。

実行結果:

  • 25.6Mパラメータ
  • 512コンテキスト長
  • 173.5Mバイトのコーパス
  • 30k学習ステップ
  • 単一RTX 4060 Ti 8GB
  • 最終学習損失: 0.5834 / 検証損失: 0.8217 / 混同行列損失: 2.15
  • 推論: HybridAttention + KVキャッシュで286.6 tok/s — 通常の注意(フルアテンション)に対して51.47倍

背景

私は自閉スペクトラム症のシステムプログラマで、2008/2009からコードを書いています。Cから始めました。私はMLをシステム開発のように捉えています。データ経路を理解する、メモリ挙動を理解する、スタックを小さく保つ、正当化できるときだけ複雑さを足す。これがこのリポジトリのだいたいの形です。

アーキテクチャ

バイトレベルのGPT風デコーダ:

  • 語彙サイズ256(バイト)
  • 8層、8ヘッド、512埋め込み次元
  • 学習済み位置埋め込み
  • 埋め込み / LMヘッドの重みを共有

注意ブロックは標準的なフルアテンションではありません。各層はHybridAttentionを使い、次を組み合わせます:

  1. 局所ウィンドウの因果(causal)注意
  2. GRUのような再帰的ステート経路
  3. 2つを混ぜ合わせる学習済みゲート

局所経路は短距離の構文を扱います。再帰経路は、二次コストを払わずに圧縮された長距離ステートを運びます。ゲートのバイアスは1で初期化しておき、学習の初期は局所寄りから始まるようにしています。

推論経路は、Tritonで最適化されたカーネルと、ローカルウィンドウ注意のためのtorch.libraryのカスタムオペレータを使用します。

コーパス

おそらく、これがこのリポジトリで最も重要な部分です。

実行は公式のRustドキュメント、コンパイラ/ライブラリ/テスト、cargo、rust-analyzer、tokio、serde、ripgrep、clap、axumから始めます。これは約31MBです。上位500クレートを取得することでコーパスを177,151,242バイトまで拡張しました(461件成功クローン)。

31Mから173.5M文字へのコーパス拡張が、このリポジトリの中で他の何よりも効きました。

学習

AdamW、lr 2e-4、weight decay 0.1、betas(0.9, 0.95)、30kステップ、1k warmup。7.6 GiBのカードでの学習メモリ使用量は約678.8 MiB。

勾配量子化、アクティベーション圧縮、選択的バックプロパゲーション、勾配ページングといった、あらゆる実験的メモリ技術は無効化しました。小さなカスタムアーキテクチャ + 混合精度 + より良いコーパスで十分でした。

損失曲線:

  • Step 0: train 5.5555 / val 5.5897
  • Step 1000: train 2.4295 / val 2.6365
  • Step 5000: train 0.9051 / val 1.0060
  • Step 10000: train 0.8065 / val 0.8723
  • Step 18500: train 0.6902 / val 0.7757
  • Step 29999: train 0.5834 / val 0.8217

最良の検証損失はstep 18.5k付近 — 過学習か、後半で停滞(プラトー)しているようです。

推論性能

  • フルアテンション O(n²): 17.96s / 5.6 tok/s
  • HybridAttention O(n·W + n·D): 0.35s / 286.6 tok/s
  • 速度向上: 51.47x — 品質低下なし

KVキャッシュ戦略: VRAM内のW=64トークンのホットウィンドウ(約256KB)、古いトークンは8-bitの大きさ+角度に圧縮、必要に応じて選択的に昇格。複雑さは、このモデルではO(n²·d)からO(4096n)へ。

5つのテストすべてパス: フォワードパス、キャッシュあり/なしでの生成、RNN状態の隔離、ウィンドウ機構。

生成品質

表面上のRust構文はそれなりに見えます。importやシグネチャはもっともらしく見えることがあります。しかし意味は弱く、反復や再帰的なナンセンスは依然としてよく起きます。現状の正直な読み取りです。

本当に面白いと思うこと

4つの独立した実験、それぞれ出荷可能な動くコードになりました:

  1. バイトレベルのRustのみでの事前学習
  2. Hybridローカル注意 + 再帰ブロックで、標準的なフル注意を置き換え
  3. コアリポジトリから、より広いクレートエコシステムへ向けたコーパス拡張
  4. 本番対応のホット/コールドKVキャッシュ・ページング — 51.47xの高速化、品質低下なし

最も分かりやすい勝ち筋はコーパス拡張です。次点の勝ち筋は、HybridAttention + キャッシュが、コンシューマ向けハードウェアで実際のインタラクティブ用途に十分な速さになっていることです。

次にやること

  1. アブレーション(除去実験) — HybridAttention vs ローカルのみ vs RNNのみ
  2. チェックポイント選択 — step 18.5kは29999より良い生成をするのか?
  3. 構文バリデーション — 出力はパース/コンパイル/型チェックできるのか?
  4. コンテキスト長のスイープ — 256から2048まで、どこでウィンドウサイズが悪さをするのか?
  5. バイト vs BPE — コーパスが5.6倍になった今、テストする価値はあるのか?

サブへの質問:

  1. 小さなコードモデルでは、混同行列損失(perplexity)以外で実際に役に立った評価は何でしょうか?
  2. コード生成向けに、ハイブリッドのローカル注意 + 再帰的注意がうまく機能するのを見た人はいますか?それとも、多くの場合はただのTransformerをスケールするだけに負けがちでしょうか?
  3. このセットアップがあるなら、より多くのトークン、より長いコンテキスト、あるいは最初により綺麗なアブレーションのどれを優先しますか?
投稿者: /u/Inevitable_Back3319
[link] [comments]