AeroJAX:JAXネイティブのCFD(微分可能でエンドツーエンド)、CPUで128×128あたり約560FPS

Reddit r/MachineLearning / 2026/4/29

💬 オピニオンDeveloper Stack & InfrastructureSignals & Early TrendsIdeas & Deep AnalysisModels & Research

要点

  • AeroJAXは、MLの最適化ループ内で使えるように、Navier–Stokes系のCFDソルバ全体をエンドツーエンドで微分可能に保つことを目的としたJAXネイティブのCFDフレームワークだと述べられています。
  • 速度・圧力・渦度といった主要な流体変数に対して完全な微分可能性を重視しており、逆設計や学習ベースのクロージャ/残差モデルの学習に向けて、シミュレーションを通じて勾配を逆伝播できる点が強調されています。
  • AeroJAXは、2次元非圧縮Navier–Stokes(投影法+圧力補正)に加えて、同一フレームワークにLBM(D2Q9)も統合し、Brinkman型の強制項と滑らかなジオメトリマスクを用いると説明しています。
  • 実装はCPU最優先のベクトル化で、現在はグリッド依存の性能として128×128で約560FPS、512×96で約300FPSなどの数値が報告されています。
  • 価値としては、CFDを微分可能なデータ生成器にし、勾配の流れを断ち切らない形で「物理×学習」のハイブリッドモデルを可能にする点が挙げられており、従来の“ブラックボックス”的なCFD/MLパイプラインとの差が示されています。

私は、逆設計や学習済みクロージャーといったMLループの中での微分可能Navier Stokesシミュレーションのために、JAXベースのCFDフレームワークを構築してきました。

最適化や学習パイプラインの内部に組み込めるように、完全なソルバースタックを微分可能に保つことが目標です。

設計上の選択:

  • 外部依存なしの完全なJAXネイティブ
  • CPUを最初にしたベクトル化実装
  • 速度・圧力・渦度の各場を通じたエンドツーエンドの微分可能性
  • Navier Stokes(投影法)およびLBM(D2Q9)サポート
  • 幾何形状の取り扱いのための滑らかなマスクを用いたBrinkmanスタイルの強制項

現在:

  • 投影と圧力補正を用いた2D非圧縮Navier Stokesソルバー
  • LBMソルバーを同一フレームワークに統合
  • 性能はCPU制約で、グリッドに依存
    • 128x128で約560 FPS
    • 512x96で約300 FPS
  • パイプライン全体を通した微分可能な流れ場
  • ソルバーのループ内部でニューラル演算子および学習済み補正を行うためのフック

本当の価値はここにあります:

  • 逆設計:幾何形状が流れにマッピングされ、勾配が幾何形状へと逆伝播する
  • 乱流または残差クロージャーを、直接ソルバー内で学習する
  • CFDを、MLシステム向けの微分可能なデータ生成器として使う
  • 勾配の流れを壊さずに、物理と学習モデルをハイブリッド化する

多くのCFDおよびMLパイプラインでは、ソルバーを依然としてブラックボックスとして扱っています。そのため、勾配に基づく設計が難しい、あるいは不可能になります。

AeroJAXは、物理の構造を保ったまま、パイプライン全体を微分可能にする試みです。

submitted by /u/LackSome307
[リンク] [コメント]

AeroJAX:JAXネイティブのCFD(微分可能でエンドツーエンド)、CPUで128×128あたり約560FPS | AI Navigate