
- 中級
- 49分
- 7本のビデオレッスン
- 4つのコード例
- 講師: Chris Achard
学習内容
GoogleのGemini、Veo、Nano Bananaモデルの背後にあるオープンソースライブラリJAXを使用し、2000万パラメータのGPT-2スタイルの言語モデルをゼロから構築します。
JAXのコアプライミティブ(自動微分、JITコンパイル、ベクトル化マッピング)を学び、ニューラルネットワークを効率的に定義・訓練・チェックポイント化する方法を理解します。
事前学習済みのMiniGPTモデルを読み込み、チャットインターフェースを通じて推論を実行。データ前処理から訓練、訓練済みLLMによるテキスト生成までのフルワークフローを体験します。
このコースについて
Googleとの提携による短期コース「JAXでLLMを構築・訓練する」を紹介します。本コースはGoogleのTPUソフトウェアチームのデベロッパーリレーションエンジニア、Chris Achardが担当します。
JAXはGoogleがGeminiをはじめとする最先端モデルの構築・訓練に用いるオープンソースの数値計算ライブラリで、NumPyに似ていますが、自動微分、JITコンパイル、多数のCPU/GPU/TPUにまたがる訓練拡張機能を備えています。本コースでは、JAXを使って言語モデルをゼロから構築・訓練して学びます。
20百万パラメータのMiniGPTスタイルのLLMを完全に実装します。アーキテクチャの定義、データの読み込み・前処理、訓練ループの実装、チェックポイント保存、完成モデルとのグラフィカルインターフェースによる対話などを行います。途中ではFlax/NNX(ニューラルネットワーク層)、Grain(データ読み込み)、Optax(最適化)、Orbax(チェックポイント管理)といったJAXエコシステムの主要ツールを扱います。
具体的には、以下を学習します。
- JAXの基礎となる自動微分、JITコンパイル、ベクトル化実行の概念と、NumPy、PyTorch、TensorFlowとの比較を通じて機械学習の全体像を理解します。
- JAXとFlax/NNXを使い、MiniGPTスタイルの言語モデルのアーキテクチャを構築。トークン埋め込みやトランスフォーマーブロックを実装し、完全かつ訓練可能なモデルを実現します。
- ミニストーリーのデータセットを読み込み・前処理。トークナイゼーション、バッチ処理、JAXの関数型実行モデルに合ったデータ構造設計を行います。
- 損失計算、Optaxによる勾配適用、JAX変換を用いた効率的な訓練ループを実装し、最後にOrbaxチェックポイントでモデルを保存します。
- 事前学習済みMiniGPTモデルを読み込み、チャットインターフェース経由で推論を実行しストーリー生成。構築・訓練・展開の全工程を体験します。
MiniGPTの構築・訓練手順は、GoogleがGeminiなどの強力なモデル開発で用いる基盤的な手順と同じです。本コースでモダンなLLM開発に必須のツールと技術を実践的に学べます。
対象者
大型言語モデルの基礎的な構築・訓練方法を理解したい開発者や機械学習実践者。Pythonや基本的な機械学習概念の知識があることが推奨されます。
コース概要
7レッスン・4コード例イントロダクション
ビデオ・3分
JAXの概要
ビデオ・6分
アーキテクチャ構築
コード例付きビデオ・10分
データ読み込み
コード例付きビデオ・6分
訓練と保存
コード例付きビデオ・8分
最終MiniGPT
コード例付きビデオ・3分
まとめ
ビデオ・1分
クイズ
読み物・10分
講師
JAXでLLMを構築・訓練する
- 中級
- 49分
- 7本のビデオレッスン
- 4つのコード例
- 講師: Chris Achard
DeepLearning.AI Proではクイズやプロジェクトなどの追加学習機能が含まれます。ぜひ体験してみてください
生成AIについてもっと学びたいですか?
DeepLearning.AIからのキュレーションしたAIニュース、コース、イベント、そしてAndrewの考えを通じて継続的に学習していきましょう!

