私は、長いシーケンス全体にわたる状態の追跡が必要なタスクのために特に設計された代替アーキテクチャであるState Flow Machine (SFM)と呼んでいるプロジェクトに取り組んできました。すべてを1台の Huawei Ascend 910 ProA NPU 上で実行しています。
私が取り組みたかった核心的な問題は、トランスフォーマーは驚くべきパターンマッチャーだが、プロセスを段階的に「シミュレート」する必要がある場合、特に訓練時に見たどんなシーケンスよりも長くなるときに苦戦するということです。彼らのアテンションのパターンは本質的には学習されたショートカットであり、入力分布が変化した瞬間にそのショートカットは崩れます。
State Slots は実際には何であるか
アテンションヘッドの代わりに、モデルは明示的なメモリスロットのバンクを持っています(小さな固定サイズのベクトルを想像してください)。各トークンごとに、ゲーティング機構がどのスロットを更新しどうするかを決定します。モデルはスロットから読み取り、更新を計算して書き戻し、まるで小さな微分可能レジスタファイルのように振る舞います。
重要な直感は、タスクが「変数に対して操作を次々と適用する」場合、モデルにはその変数の現在の値を保存する場所を持ち、それを更新するべきであり、過去の全トークンに対するアテンションから全計算履歴を再構築しようとするべきではない、ということです。アテンションは「過去のどのトークンが意味を持つか」を示します。一方スロットは「現在の状態は何か、そしてこのトークンはそれをどう変えるか」を示します。
これはDeltaNet、Linear Attention、および状態空間モデル(Mamba、RWKV)からのアイデアに関連していますが、より明示的には、スロットは直接アドレス指定可能で、学習されたゲートを介して更新されるため、暗黙の再帰状態ではありません。
ベンチマーク
合成プログラム状態追跡: x = 42; x += 17; x -= 8; x *= 2; ... のようなシーケンスが与えられたとき、x の最終値(整数 0–100、101クラス分類としてフレーミング)を予測します。
- トレーニングデータ:10,000 のプログラム、操作は 10–27 個、難易度は高い(すべてのオペレーション:加算、減算、乗算、整数除算、剰余、設定)、シード 42
- 検証:同様の分布の 1,000 プログラム
- 評価:訓練プログラム長の 1×(分布内)、2×、4×、8×、16×、および 32×での評価
これは意図的におもちゃのタスクです。しかし、私が関心を持つ正確な能力を分離しています。すなわち、モデルが訓練よりもはるかに長いシーケンスでも正確な実行状態を維持できるか、ということです。
結果
完全一致精度:
| 長さ | State Slots(961Kパラメータ) | Transformer-Fair(443K) | Transformer-Large(2.2M) |
|---|---|---|---|
| 1×(10 操作) | 99.9% | 100.0% | 100.0% |
| 2×(20 操作) | 92.9% | 99.0% | 99.5% |
| 4×(40 操作) | 62.0% | 1.9% | 3.1% |
| 8×(80 操作) | 35.3% | 1.3% | 1.0% |
| 16×(160 操作) | 5.1% | 0.9% | 0.7% |
| 32×(320 操作) | 5.0% | 1.0% | 0.8% |
一般化比率(どれだけ精度を保持するか):
| モデル | 4×/1× | 8×/1× |
|---|---|---|
| State Slots | 0.62× | 0.35× |
| Transformer-Fair | 0.02× | 0.01× |
| Transformer-Large | 0.03× | 0.01× |
外挿長さでの平均絶対誤差(スケール0–100):
| 長さ | State Slots | Transformer-Fair | Transformer-Large |
|---|---|---|---|
| 4× | 14.03 | 40.33 | 36.76 |
| 8× | 26.73 | 41.71 | 41.19 |
トランスフォーマーは、4×以降では実質的にランダムに推測している(0–100のスケールでMAE が約40 は、均等なランダム推測の予想誤差に近い)。State Slots は依然として意味のある予測を行っています。
公正性の維持
これは全体を通じて大きな懸念でした。比較が意味を持つのは、両方のアーキテクチャが同じ利点を得る場合だけです:
- 同じ目的: すべてのモデルは101クラスのクロスエントロピーを使用(回帰ではなく、MSEから分類へ切り替えたのが最大の改善の1つです)。
- 同じ学習率グリッド探索: すべてのモデルを[3e-4, 5e-4, 1e-3, 2e-3, 5e-3] でテストし、2K サブセットの検証精度で最良を選択。
- 同じデータ: 同一のトレイン/バリデーション分割、同じトークナイザー、同じ難易度生成。
- 同じ精度: 全体で FP32(AMP の利点なし)。
- パラメータ比較: State Slots は 961K で、Transformer-Fair (443K) と Transformer-Large (2.2M) の間に位置します。どちらのトランスフォーマーサイズも外挿には利点になりません。
唯一の非対称性:State Slots は中間状態の監督(各操作ステップでの補助損失)を使用している点で、トランスフォーマーにはないものです。これはアーキテクチャ設計の一部と言えます。スロットは中間状態を持っていますが、それについては透明性を保ちたいと思います。
11%から99.9%への旅路
State Slots の最初のバージョン(v1)はひどかった:分布内での完全一致が11.2%。3つの変更でうまく機能するようになりました:
| バージョン | 変更点 | 1× EM | 4× EM | 4×/1× 比率 |
|---|---|---|---|---|
| v1 | MSE 回帰、LR 3e-4、補助損失なし | 11.2% | 8.9% | 0.79× |
| v2 | + 101クラス CE、+ 中間監視、+ LR スイープ | 100.0% | 87.8% | 0.88× |
| v3 (final) | + 同等の CE ヘッドを備えた公正な Transformer ベースライン、+ 16×/32× 評価 | 99.9% | 62.0% | 0.62× |
注:v2 の数値は、トランスフォーマーがまだ旧 MSE 目的を使用していたため膨らんでいました。同じ分類ヘッドと LR スイープをトランスフォーマーに与えると、分布内では追いつきました(予想どおり)が、外挿では崩れました。v3 の 4× での 62% は、正直で、 apples-to-apples の数字です。
State Slots の 4× スコア(87.8% → 62.0%)の v2 → v3 の低下は、v3 がデータを再生成し、訓練設定がわずかに異なっていたために起こりました。重要な比較は常に同じ実行内で行われるべきです。
これが証明するものではない
過大な主張を避けたいと思います:
- これは合成タスクです。 これは状態追跡のためのアーキテクチャ的帰納的バイアスについて何かを教えてくれますが、言語モデリング、コード生成、あるいは実世界での利用について直接何かを示すものではありません。
- 961K パラメータは小さい。 スケーリング挙動は未知です。より大きなスケールでは、アーキテクチャがトランスフォーマーが直面しない壁にぶつかる可能性があります。
- このタスクには明確で明示的な状態があります。 実際のプログラムにはヒープ、スタック、クロージャなど複雑な状態があります。このベンチマークは1つの整数変数だけを追跡します。
- 16×および32×はまだ悪い。 16×での5%は大きくはありません。緩やかな劣化はトランスフォーマーの崖よりはるかに良いですが、改善の余地はまだ多くあります。
- Mamba/RWKV/他のSSMsとの比較はありません。 これらは自然な競合相手で、まだベンチマークしていません。彼らがこのタスクで通常のトランスフォーマーよりも良い結果を出す可能性もあります。
今後の予定
- MambaとRWKVのベースラインを追加 — これらはサブ2乗の状態追跡の実際の競合相手です。
- アブレーション: スロット数(現状16)、補助損失の重み、忘却ゲートのバリアント。
- 難易度の高いタスク: 複数の変数、条件分岐、ループ、関数呼び出し。
- スケーリング: 利点が維持されるかを、1000万以上のパラメータでテストします。
- ハイブリッド: DeltaNetスタイルの忘却ゲートとスロットを混ぜ合わせ、両者の長所を組み合わせる可能性。
再現方法
すべて1つのNPU/GPUで動作します。コードは以下にあります: github.com/changcheng967/state-flow-machine
git clone https://github.com/changcheng967/state-flow-machine.git cd state-flow-machine python experiments/exp0_state_tracking/finish_experiment.py データセット: トレーニング10K / バリデーション1K、難易度高、シード 42。Ascend 910 ProA でのフル実行は約30分。結果は outputs/exp0/evaluation_results.json および outputs/exp0/length_generalization.png に保存されます。
質問にお答えしたり、完全なトレーニングログを共有したりするのは喜んで。
[リンク] [コメント]


