私は、NLLB-200 600Mモデルをファインチューニングする際にこの特定のバッチング問題に直面したことをきっかけに、dynabatch という小さなpytorch用サンプラを作りました。
RTX 5090で学習したところ、私が使えた最大の固定バッチサイズは8で、これ以上だとOOMになります。nvidia-smi , で学習しつつ監視していると、実際にGPUを強く負荷しているのはごく一部のバッチだけのように見えました。多くの時間、利用率はかなり低い状態でした。固定バッチサイズは、最長のソース/ターゲット例によって決まっている一方で、短い例はバッチあたりにより多くのサンプルを入れる余地があるのではないか、と推測しました。
そこで、シーケンス長が変わるにつれてバッチサイズも変えられるようにしてみました。アイデアの要点は以下です:
- トークン長で例を並べ替え、最長から始める
- 最初のバッチを「この条件で収まる最も難しいバッチ」とみなす
- 以降の短いバッチでは、より大きい候補バッチサイズを試す
- 最初のバッチに対するメモリ負荷を予測するために、小さなXGB回帰器を使う
- 安全なしきい値の下に収まる最大の候補を選ぶ
これは主にエンコーダ・デコーダモデル向け、特にMTではソース長がターゲット長の有用な代理指標になりやすいので、その文脈で考えています。デコーダのみモデルの最初のツールとしては使わないと思います。シーケンスパッキングのほうが勝ち筋だと思います。
私の学習ベンチマークでは、固定バッチ学習に比べて約3.3xのスループット改善が得られました。この数値は自分のセットアップに基づくものですが、一般的な主張として読むべきではないとも思っています。collab T4の生成ベンチマークでは、得られた改善は約1.06x - 1.21xにとどまりました
また、回帰器も経験的なもので、測定したGPUメモリ使用量から学習させているため、時々外れる可能性があり、モデル/トークナイザによって挙動が少し変わるかもしれません。ただし、過大に見積もってOOMになりそうな場合のフォールバックも追加しました。(興味がある方のために、回帰器の学習ノートブックも追加しました)
なので正直に言うと、これは特にデコーダのみの時代ではかなりニッチなツールだと思いますが、エンコーダ・デコーダのMTモデルで学習/生成をしている人の役に立てば嬉しいです。
Repo: https://github.com/bendangnuksung/dynabatch
PyPI: https://pypi.org/project/dynabatch/
[link] [comments]




