AeroJAX: JAX-native CFD, differentiable end-to-end. ~560 FPS at 128x128 on CPU [P]

Reddit r/MachineLearning / 4/29/2026


Key Points

  • The article presents AeroJAX, a JAX-native CFD framework designed to keep a differentiable Navier–Stokes/CFD solver fully end-to-end usable inside ML optimization loops.
  • It emphasizes full differentiability across core flow variables (velocity, pressure, and vorticity), enabling gradients to propagate back through the simulation for inverse design and learned closure/residual models.
  • AeroJAX supports both a 2D incompressible Navier–Stokes solver (projection method with pressure correction) and an integrated LBM option (D2Q9), with Brinkman-style forcing and smooth geometry masks.
  • The current implementation is CPU-focused and vectorized, reporting performance such as ~560 FPS at 128×128 and ~300 FPS at 512×96, with grid-dependent speed.
  • The author positions the main value as making CFD a differentiable data generator and enabling hybrid physics + learned models without breaking gradient flow, unlike common “black-box” CFD/ML pipelines.

I have been building a JAX-based CFD framework for differentiable Navier–Stokes simulation inside ML loops, such as inverse design and learned closures.

The goal is to keep the full solver stack differentiable so it can sit inside optimisation and learning pipelines.

Design choices:

  • Fully JAX-native with no external dependencies beyond JAX
  • CPU-first, vectorized implementation
  • End-to-end differentiability through velocity, pressure, and vorticity fields
  • Navier–Stokes (projection method) and LBM (D2Q9) support
  • Brinkman-style forcing with smooth masks for geometry handling
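To make the Brinkman approach concrete, here is a minimal sketch of a smooth geometry mask plus penalty forcing for a circular obstacle. The function names, the tanh profile, and all parameter values are illustrative assumptions, not AeroJAX's actual API:

```python
import jax.numpy as jnp

def smooth_mask(X, Y, cx, cy, r, eps=0.02):
    """Smooth indicator of a circular obstacle: ~1 inside, ~0 outside.

    A tanh of the signed distance keeps the mask differentiable with
    respect to the geometry parameters (cx, cy, r), which is what lets
    gradients reach the geometry.
    """
    d = jnp.sqrt((X - cx) ** 2 + (Y - cy) ** 2) - r
    return 0.5 * (1.0 - jnp.tanh(d / eps))

def brinkman_force(u, mask, eta=1e3):
    """Brinkman penalty: a large damping term that drives velocity
    toward zero inside the solid region and vanishes in the fluid."""
    return -eta * mask * u
```

Because the mask is smooth rather than a hard 0/1 stencil, the forcing term stays differentiable in the obstacle's position and radius.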

Currently:

  • 2D incompressible Navier–Stokes solver using projection with pressure correction
  • LBM (D2Q9) solver integrated into the same framework
  • Performance is CPU-bound and grid-dependent
    • ~560 FPS at 128×128
    • ~300 FPS at 512×96
  • Differentiable flow fields throughout the pipeline
  • Hooks for neural operators and learned corrections inside the solver loop
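The "hooks for learned corrections" idea can be sketched as a solver step that applies a parameterized residual after the physics update, with `jax.grad` differentiating through the whole rollout. The diffusion-only update and the scalar `theta` correction below are toy stand-ins, not AeroJAX's API; in practice the correction would be a neural operator:

```python
import jax
import jax.numpy as jnp

def diffuse(u, nu=0.1, dt=0.1):
    """Explicit diffusion on a periodic grid (5-point Laplacian)."""
    lap = (jnp.roll(u, 1, 0) + jnp.roll(u, -1, 0)
           + jnp.roll(u, 1, 1) + jnp.roll(u, -1, 1) - 4.0 * u)
    return u + nu * dt * lap

def step(u, theta):
    """One solver step with a learned residual applied inside the loop.

    `theta` is a toy scalar correction; a real closure model would put a
    network's parameters in its place.
    """
    u = diffuse(u)
    return u + theta * u  # hypothetical learned-correction hook

def loss(theta, u0, target, n_steps=5):
    u = u0
    for _ in range(n_steps):
        u = step(u, theta)
    return jnp.mean((u - target) ** 2)

# Gradients flow through every solver step, so the closure can be
# trained end-to-end against rollout error.
grad_theta = jax.grad(loss)
```

The point is that nothing in the step function blocks reverse-mode autodiff, so the correction is trained against rollout error rather than single-step targets.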

Where I see the real value:

  • Inverse design where geometry maps to flow and gradients propagate back to geometry
  • Learning turbulence or residual closures directly in the solver
  • Using CFD as a differentiable data generator for ML systems
  • Hybrid physics and learned models without breaking gradient flow
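As a sketch of the inverse-design point: when the geometry enters the solver through a smooth mask, `jax.grad` can return the sensitivity of a flow objective with respect to a shape parameter. Everything here (the tiny damped-diffusion "solver", the drag proxy, the obstacle parameterization) is a toy stand-in under those assumptions, not AeroJAX's actual API:

```python
import jax
import jax.numpy as jnp

def objective(cx, n=32, n_steps=10):
    """Toy flow objective as a function of the obstacle's x-position.

    A uniform velocity field is diffused and damped inside a smooth
    circular mask; the returned scalar is a crude drag proxy.
    """
    x = jnp.linspace(0.0, 1.0, n)
    X, Y = jnp.meshgrid(x, x, indexing="ij")
    d = jnp.sqrt((X - cx) ** 2 + (Y - 0.5) ** 2) - 0.1
    mask = 0.5 * (1.0 - jnp.tanh(d / 0.02))  # smooth, differentiable in cx
    u = jnp.ones((n, n))
    for _ in range(n_steps):
        lap = (jnp.roll(u, 1, 0) + jnp.roll(u, -1, 0)
               + jnp.roll(u, 1, 1) + jnp.roll(u, -1, 1) - 4.0 * u)
        u = u + 0.05 * lap - 0.5 * mask * u  # diffusion + Brinkman damping
    return jnp.sum(mask * u ** 2)

# Geometry gradient obtained by differentiating through the rollout:
# geometry -> flow -> objective, and back.
dJ_dcx = jax.grad(objective)(0.4)
```

A gradient-based optimizer can then move `cx` (or any other shape parameter) directly against the objective, which is exactly what a black-box solver cannot offer.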

Most CFD and ML pipelines still treat the solver as a black box, which makes gradient-based design difficult or impossible.

AeroJAX is an attempt to keep the physics structure intact while making the entire pipeline differentiable.

submitted by /u/LackSome307