PyTorchでゼロから分散学習の並列化を実装する、小さな教育用リポジトリを作りました:
https://github.com/shreyansh26/pytorch-distributed-training-from-scratch
高レベルの抽象化を使うのではなく、コードは順伝播/逆伝播のロジックとコレクティブ(集団通信)を明示的に記述しているので、アルゴリズムをそのまま直接見ることができます。
モデルは意図的に、合成タスク上で繰り返すだけの2つの行列積(2-matmul)のMLPブロックで構成されているため、主に研究対象は通信パターンです。
これは主に、巨大なフレームワークを掘り下げることなく、分散学習の数式をそのまま実行可能なコードへ対応付けたい人向けに作りました。
[リンク] [コメント]




