Mamba 1 & 2 から Mamba 3 へのアーキテクチャ・アップグレード

Reddit r/LocalLLaMA / 2026/4/9

💬 オピニオンDeveloper Stack & InfrastructureTools & Practical UsageModels & Research

要点

  • リポジトリは、Mamba-1/Mamba-2 のシーケンス・ブロック・パラメータを、3つの特定の数学的な不一致(シーケンス反転、次元の崩壊、逆softplusによる再パラメータ化)に対処することで、同等の Mamba-3 アーキテクチャへ変換する重み移植手法を記述しています。
  • 変換スクリプト(mamba1_to_mamba3_converter.py)を使用し、Mamba-1 の in_proj 重みを切り出して並べ替え、D/dt_bias を Mamba-3 の nhead グルーピングへプーリングし、数値的挙動を維持するために inverse-softplus 変換を通じて dt_bias を再マッピングします。
  • 厳格な 12GB VRAM 制限内で学習するために、リカバリ・トレーナー(mamba3_recovery_trainer.py)は、グラフを即座に解放してメモリの急増を抑えるための、サンプルごとのマイクロ・バックワード手順を含む独自のメモリ節約型学習メカニズムに依存しています。
  • 学習パイプラインはさらに、移植された「連想メモリ」重みの 99% を凍結し、フェーズAで新たに導入された Mamba-3 のゲート・パラメータのみをアンフリーズして選択的リカバリを行うことで、計算量とメモリ使用量を削減します。

このリポジトリには、Mamba-1/Mamba-2 のアーキテクチャから重みを直接 Mamba-3 のゲートへ構造的に移植することで、スクラッチから学習回避を行うための手法とスクリプトが含まれています。

世代間で生じる数学的な不整合を処理し、厳格な 12GB VRAM の上限の範囲内で、Mamba-3 モデルを再びコヒーレンス(整合性)に戻せる二段階の構造復旧トレーニング・パイプラインを提供します。

手法(Methodology)

Mamba 1 から Mamba 3 へシーケンスブロックを移植する際、モデルが純粋なガラクタを出力しないように、3 つの重要な数学的な不一致を解決する必要があります:

1. [x, z] と [z, x] のシーケンス反転

  • 問題: Mamba-1 の in_proj は次元をメイン分岐(x)とゲーティング分岐(z)に分割します。Mamba-3 は [z, x] を想定しています。重みを盲目的にコピーすると、ネットワークの順伝播ロジックが物理的に逆転します。
  • 解決策: mamba1_to_mamba3_converter.py スクリプトは、in_proj の重み行列を d_inner で正確にスライスし、注入前に上半分と下半分を反転させることで、数学的に変換します。

2. 次元の崩壊(dt_bias, D)

  • 問題: Mamba-1 は構造的な D(スキップ接続)と dt_bias を、シーケンス長全体にわたってスケールします。Mamba-3 はそれらを、特定のサイズに調整された nheads のヘッダー(グループ)へプールします。
  • 解決策: スクリプトは、元の構造信号のスケールを保持するために、(例:5120 の塊を 64 のプールへ平均化するなどの)能動的な次元プーリング処理を実行します。

3. 逆 Softplus の再パラメータ化

  • 問題: Mamba-3 のカーネル変数には、特定のスケーリング・ロジックが必要です。バイアスの生値は、Triton の softplus 活性化層を通すと異なる対応関係になります。
  • 解決策: スクリプトは、変換された dt_bias の値に対して torch.log(torch.exp(weights) - 1.0) をマッピングし、数値的な等価性を維持します。

12GB VRAM 最適化

2.8B モデルを通常どおり学習するには、通常 ~18GB VRAM が必要です。標準的なアクティベーション・チェックポイントは、カスタムの Mamba-3 Triton カーネルと衝突することが多いため、mamba3_recovery_trainer.py では 2 つの方法で VRAM を最適化しています:

  1. サンプル単位のマイクロ・バックワーズ: バッチされたブロックに対して loss.backward() をそのまま実行する代わりに、ループを次のように落とし込みます:for sample in batch: loss.backward() graph.free() 勾配は安全に蓄積されますが、グラフはステップごとに即座に解放されるため、メモリの急増が抑えられます。
  2. フェーズ A の選択的フリーズ: 「連想記憶(associative memory)」を表す移植済みモデル重みの 99% を凍結し、追加されたばかりの Mamba-3 パラメータ・ゲートのみをアンフリーズします。

復旧(Recovery)パイプライン

移植されたモデルは、話す方法を忘れてしまった知的なエンジンのように振る舞います。復旧パイプラインは、新しいゲートを古いロジックに適応させます。

  • フェーズ A(150 ステップ): 2.8B モデル内のすべてを凍結し、統合されたばかりの Mamba-3 固有のゲート(B_bias, C_bias など)のみを有効にします。ゲートがレガシー行列にキャリブレーションされるにつれて、損失は急速に崩れ落ちます。
  • フェーズ B(>1000 ステップ): モデルは Low-Rank Adapter(LoRA)行列を出力にクリーンに注入し、推論を解放して能力を安定化させます。

使い方(Usage)

  1. ベースの Mamba .safetensors または .bin チェックポイントを正しいディレクトリに配置します。
  2. python mamba1_to_mamba3_converter.py を実行して、最初の移植済みシェル・チェックポイントを作成します。
  3. python mamba3_recovery_trainer.py を実行して、Phase A / Phase B のトレーニングループによりモデルのアーキテクチャを構造的に修復します。 https://github.com/batteryphil/mamba1and2-to-3.git
submitted by /u/Just-Ad-6488
[link] [comments]