新しいオプティマイザ「Rose」—低VRAMで使いやすく高い結果、Apache 2.0

Reddit r/MachineLearning / 2026/4/24

📰 ニュースDeveloper Stack & InfrastructureTools & Practical UsageModels & Research

要点

  • PyTorch用の新しいオプティマイザ「Rose」がリリースされ、著者が数年かけて独自に開発したもので、著者の母への想いにちなんで名付けられています。
  • Roseはステートレス設計で、8-bit AdamWよりも少ないメモリ使用を目指しており、オプティマイザの状態を最小化することでVRAM削減につなげるとされています。
  • 著者は、Roseが高速な収束、低いVRAM使用、優れた汎化性能を持つと主張し、実際に試して結果を共有してほしいと呼びかけています。
  • ベンチマークは誤解を招く可能性があるため、Roseは学習損失が高くても検証損失が低くなる場合があり、最終的には学習済みモデルの出力で判断すべきだと注意しています。
  • RoseはApache 2.0ライセンスで提供され、GitHub上でコードと詳細情報が公開されています。

こんにちは、世界!私はここ数年、自分自身で研究し、開発してきた新しいPyTorchのオプティマイザを最近リリースしました。これは、私の母がAIでの発見や進捗について語りを聞くのが大好きだったことを記念して、「Rose」と名付けています。

技術的な詳細にはあまり踏み込みません(GitHubリポジトリで読めます)が、主な利点は次の通りです:

  • ステートレスです。つまり、8-bit AdamWよりもさらにメモリ使用量が少なくなります。一時的な作業メモリがなければ、メモリ使用量は素のバニラSGD(without)(モメンタムなし)と同程度まで下がります。
  • 高速な収束、低いVRAM使用量、そして優れた汎化性能。ええ、知っています……信じられないくらい良さそうに聞こえるかもしれません。ぜひ自分で試してみて、あなたの感想を教えてください。良い点でも悪い点でも、みなさんの体験談をぜひ聞きたいです。
  • Apache 2.0ライセンス

コードと詳細情報は以下で見つかります: https://github.com/MatthewK78/Rose

ベンチマークは時に誤解を招くことがあります。たとえば、Roseの方がAdamよりも学習損失が高い場合がある一方で、検証損失はRoseの方が低いことがあります。最終的に本当に重要なのは、学習済みモデルの実際の出力です。そしてそれですら、主観が入り得ます。ぜひ自分で試してみて、各自で結論を出してください。とはいえ、ここではいくつかの簡単なベンチマークを示します。

MNIST学習(同じシード):

[Rose] lr=3e-3, デフォルトのハイパーパラメータ text Epoch 1: avg loss 0.0516, acc 9827/10000 (98.27%) Epoch 2: avg loss 0.0372, acc 9874/10000 (98.74%) Epoch 3: avg loss 0.0415, acc 9870/10000 (98.70%) Epoch 4: avg loss 0.0433, acc 9876/10000 (98.76%) Epoch 5: avg loss 0.0475, acc 9884/10000 (98.84%) Epoch 6: avg loss 0.0449, acc 9892/10000 (98.92%) Epoch 7: avg loss 0.0481, acc 9907/10000 (99.07%) Epoch 8: avg loss 0.0544, acc 9918/10000 (99.18%) Epoch 9: avg loss 0.0605, acc 9901/10000 (99.01%) Epoch 10: avg loss 0.0668, acc 9904/10000 (99.04%) Epoch 11: avg loss 0.0566, acc 9934/10000 (99.34%) Epoch 12: avg loss 0.0581, acc 9929/10000 (99.29%) Epoch 13: avg loss 0.0723, acc 9919/10000 (99.19%) Epoch 14: avg loss 0.0845, acc 9925/10000 (99.25%) Epoch 15: avg loss 0.0690, acc 9931/10000 (99.31%)

[AdamW] lr=2.5e-3, デフォルトのハイパーパラメータ text Epoch 1: avg loss 0.0480, acc 9851/10000 (98.51%) Epoch 2: avg loss 0.0395, acc 9871/10000 (98.71%) Epoch 3: avg loss 0.0338, acc 9887/10000 (98.87%) Epoch 4: avg loss 0.0408, acc 9884/10000 (98.84%) Epoch 5: avg loss 0.0369, acc 9896/10000 (98.96%) Epoch 6: avg loss 0.0332, acc 9897/10000 (98.97%) Epoch 7: avg loss 0.0344, acc 9897/10000 (98.97%) Epoch 8: avg loss 0.0296, acc 9910/10000 (99.10%) Epoch 9: avg loss 0.0356, acc 9892/10000 (98.92%) Epoch 10: avg loss 0.0324, acc 9911/10000 (99.11%) Epoch 11: avg loss 0.0334, acc 9910/10000 (99.10%) Epoch 12: avg loss 0.0323, acc 9916/10000 (99.16%) Epoch 13: avg loss 0.0310, acc 9918/10000 (99.18%) Epoch 14: avg loss 0.0292, acc 9930/10000 (99.30%) Epoch 15: avg loss 0.0295, acc 9925/10000 (99.25%)


メモリオーバーヘッド(パラメータに対するオプティマイザ状態の相対量):

  • Rose: 0×
  • SGD(モメンタムなし):0×
  • Adafactor: 約0.5-1×(factorized)
  • SGD(モメンタムあり):1×
  • AdaGrad: 1×
  • Lion: 1×
  • Adam/AdamW/RAdam/NAdam: 2×
  • Sophia: 約2×
  • Prodigy: 約2-3×

OpenAIにはGitHubリポジトリopenai/parameter-golfでチャレンジがあります。何も変更せずに簡単なテストを実行すると、次の結果になります:

[Adam] final_int8_zlib_roundtrip_exact val_loss:3.79053424 val_bpb:2.24496788

train_gpt.pyファイルの中で、単にoptimizer_tokoptimizer_scalarを置き換えるだけで、この結果になります:

[Rose] final_int8_zlib_roundtrip_exact val_loss:3.74317755 val_bpb:2.21692059

optimizer_muonはそのままにしておきました。補足ですが、私はMuonの性能に直接競り合おうとしているわけではありません。ただし、Muonの大きな問題は、2次元パラメータにしか対応しておらず、残りを埋めるためにAdamなどの他のオプティマイザに頼っている点です。また、メモリもより多く使います。私のRoseオプティマイザの最大の強みの一つは、極端に低いメモリ使用量です。

興味があれば、もう少し詳しく見てみましょう(warmup stepsは削除):

[Adam] text world_size:2 grad_accum_steps:4 sdp_backends:cudnn=False flash=True mem_efficient=False math=False attention_mode:gqa num_heads:8 num_kv_heads:4 tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 train_batch_tokens:16384 train_seq_len:1024 iterations:200 warmup_steps:20 max_wallclock_seconds:600.000 seed:1337 < 20 warmup steps were here > step:1/200 train_loss:6.9441 train_time:156ms step_avg:155.60ms step:2/200 train_loss:18.0591 train_time:283ms step_avg:141.70ms step:3/200 train_loss:12.4893 train_time:373ms step_avg:124.43ms step:4/200 train_loss:7.8984 train_time:461ms step_avg:115.37ms step:5/200 train_loss:6.7623 train_time:552ms step_avg:110.46ms step:6/200 train_loss:6.7258 train_time:640ms step_avg:106.74ms step:7/200 train_loss:6.5040 train_time:729ms step_avg:104.14ms step:8/200 train_loss:6.5109 train_time:817ms step_avg:102.16ms step:9/200 train_loss:6.1916 train_time:906ms step_avg:100.61ms step:10/200 train_loss:6.0549 train_time:994ms step_avg:99.45ms step:200/200 train_loss:3.8346 train_time:18892ms step_avg:94.46ms step:200/200 val_loss:3.7902 val_bpb:2.2448 train_time:18893ms step_avg:94.46ms peak memory allocated: 586 MiB reserved: 614 MiB Serialized model: 67224983 bytes Code size: 48164 bytes Total submission size: 67273147 bytes Serialized model int8+zlib: 11374265 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) Total submission size int8+zlib: 11422429 bytes final_int8_zlib_roundtrip val_loss:3.7905 val_bpb:2.2450 eval_time:67924ms final_int8_zlib_roundtrip_exact val_loss:3.79053424 val_bpb:2.24496788

[Rose]

optimizer_tok = Rose([{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], lr=token_lr, stabilize=False, compute_dtype=None)

optimizer_scalar = Rose([{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], lr=args.scalar_lr, stabilize=False, compute_dtype=None)

text world_size:2 grad_accum_steps:4 sdp_backends:cudnn=False flash=True mem_efficient=False math=False attention_mode:gqa num_heads:8 num_kv_heads:4 tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 train_batch_tokens:16384 train_seq_len:1024 iterations:200 warmup_steps:20 max_wallclock_seconds:600.000 seed:1337 < 20 warmup steps were here > step:1/200 train_loss:6.9441 train_time:173ms step_avg:173.15ms step:2/200 train_loss:6.4086 train_time:305ms step_avg:152.69ms step:3/200 train_loss:6.2232 train_time:433ms step_avg:144.21ms step:4/200 train_loss:6.1242 train_time:557ms step_avg:139.24ms step:5/200 train_loss:5.9950 train_time:681ms step_avg:136.23ms step:6/200 train_loss:6.0386 train_time:806ms step_avg:134.38ms step:7/200 train_loss:5.9189 train_time:933ms step_avg:133.22ms step:8/200 train_loss:5.8817 train_time:1062ms step_avg:132.78ms step:9/200 train_loss:5.5375 train_time:1192ms step_avg:132.43ms step:10/200 train_loss:5.4599 train_time:1322ms step_avg:132.25ms step:200/200 train_loss:3.7445 train_time:24983ms step_avg:124.91ms step:200/200 val_loss:3.7390 val_bpb:2.2144 train_time:24984ms step_avg:124.92ms peak memory allocated: 584 MiB reserved: 612 MiB Serialized model: 67224983 bytes Code size: 48449 bytes Total submission size: 67273432 bytes Serialized model int8+zlib: 11209724 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) Total submission size int8+zlib: 11258173 bytes final_int8_zlib_roundtrip val_loss:3.7432 val_bpb:2.2169 eval_time:65817ms final_int8_zlib_roundtrip_exact val_loss:3.74317755 val_bpb:2.21692059


AdamWとRoseの間での学習の視覚的比較: https://www.reddit.com/r/StableDiffusion/comments/1ss85os/training_comparison_adamw_on_the_left_rose_on_the/


[更新ルール] ```text

1. 分離された重み減衰(decoupled weight decay)

θ ← (1 − η_wd · λ) · θ

2. 勾配の集中化(任意)

g̃_i ← g_i − mean(g_i) # 先頭以外の全軸にわたる平均

3. スライスごとの範囲

R_i ← |max(g̃_i)| − min(g̃_i) # スライスごとに1つのスカラー

4. CV信頼ゲーティング(任意)

μ_R ← mean(R), σ_R ← std(R) # 全スライスにわたって τ ← μ_R / (σ_R + μ_R) # 同等的に 1/(1 + CV) D_i ← (1 − τ) · μ_R + τ · R_i # グローバルとローカルの間を補間

5. 更新

θ ← θ − η · g̃ / D ```

投稿者 /u/ECF630
[リンク] [コメント]