I have been building AeroJAX, a JAX-based CFD framework for differentiable Navier-Stokes simulation inside ML loops such as inverse design and learned closures. The goal is to keep the full solver stack differentiable so it can sit inside optimisation and learning pipelines.
Design choices:
- Fully JAX-native, with no dependencies beyond JAX
- CPU-first, vectorized implementation
- End-to-end differentiability through velocity, pressure, and vorticity fields
- Navier-Stokes (projection method) and lattice Boltzmann (D2Q9) solvers
- Brinkman-style forcing with smooth masks for geometry handling
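To make the Brinkman idea concrete, here is a minimal sketch of a smooth, differentiable penalisation in JAX. The function names, the tanh transition width, and the disk geometry are illustrative assumptions, not AeroJAX's actual API:

```python
import jax.numpy as jnp

def smooth_disk_mask(X, Y, cx, cy, radius, eps=0.05):
    """Smooth solid indicator: ~1 inside the disk, ~0 outside.

    A tanh transition of width ~eps keeps the mask differentiable
    with respect to the geometry parameters (cx, cy, radius).
    """
    d = jnp.sqrt((X - cx) ** 2 + (Y - cy) ** 2) - radius
    return 0.5 * (1.0 - jnp.tanh(d / eps))

def brinkman_force(vel, mask, eta=1e3):
    """Brinkman penalisation: damp velocity toward zero inside the solid."""
    return -eta * mask * vel

# Toy example on a small grid
n = 64
x = jnp.linspace(0.0, 1.0, n)
X, Y = jnp.meshgrid(x, x, indexing="ij")
mask = smooth_disk_mask(X, Y, 0.5, 0.5, 0.2)
u = jnp.ones((n, n))            # uniform toy velocity field
f = brinkman_force(u, mask)     # strong damping inside the disk, ~0 outside
```

Because the mask is smooth rather than binary, gradients flow through the geometry parameters instead of dying at a hard solid boundary.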
Currently:
- 2D incompressible Navier-Stokes solver using a projection (pressure-correction) method
- LBM solver integrated into the same framework
- Performance is CPU-bound and grid-dependent: ~560 FPS at 128x128, ~300 FPS at 512x96
- Differentiable flow fields throughout the pipeline
- Hooks for neural operators and learned corrections inside the solver loop
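As an illustration of the projection step (a sketch of the general technique, not AeroJAX's actual implementation), here is a minimal spectral pressure projection on a periodic grid: remove the gradient part of the velocity so the result is divergence-free, and every operation stays differentiable:

```python
import jax
import jax.numpy as jnp

# Double precision keeps the divergence residual near machine zero.
jax.config.update("jax_enable_x64", True)

def project(u, v, dx):
    """Leray projection of (u, v) onto the divergence-free space.

    Works mode-by-mode in Fourier space on a periodic grid:
    subtract k (k . u_hat) / |k|^2 from each velocity mode.
    """
    n = u.shape[0]
    k = 2.0 * jnp.pi * jnp.fft.fftfreq(n, d=dx)
    kx, ky = jnp.meshgrid(k, k, indexing="ij")
    k2 = kx**2 + ky**2
    k2_safe = jnp.where(k2 > 0, k2, 1.0)   # guard the mean mode
    u_hat, v_hat = jnp.fft.fft2(u), jnp.fft.fft2(v)
    div_hat = kx * u_hat + ky * v_hat      # shared factor of i omitted; it cancels
    u_hat = u_hat - kx * div_hat / k2_safe
    v_hat = v_hat - ky * div_hat / k2_safe
    return jnp.real(jnp.fft.ifft2(u_hat)), jnp.real(jnp.fft.ifft2(v_hat))

# A smooth periodic field with nonzero divergence
n = 32
dx = 1.0 / n
xs = jnp.arange(n) * dx
X, Y = jnp.meshgrid(xs, xs, indexing="ij")
u0 = jnp.sin(2 * jnp.pi * X) * jnp.cos(2 * jnp.pi * Y)
v0 = jnp.zeros_like(u0)
u1, v1 = project(u0, v0, dx)
```

In a full pressure-correction scheme this projection is applied to the provisional velocity after the advection-diffusion (and Brinkman forcing) step of each time step.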
Where I see the real value:
- Inverse design where geometry maps to flow and gradients propagate back to geometry
- Learning turbulence or residual closures directly in the solver
- Using CFD as a differentiable data generator for ML systems
- Hybrid physics and learned models without breaking gradient flow
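The inverse-design point can be made concrete with a toy example: a scalar geometry parameter (a disk radius, here) enters the flow through a smooth Brinkman mask, and `jax.grad` differentiates a flow objective with respect to it. Everything in this sketch is an illustrative assumption; in particular, the single implicit damping step stands in for a full solve:

```python
import jax
import jax.numpy as jnp

def flow_loss(radius, n=64, eta=1e3, dt=1e-3, eps=0.05):
    """Toy objective: kinetic energy remaining after one Brinkman-damped step."""
    x = jnp.linspace(0.0, 1.0, n)
    X, Y = jnp.meshgrid(x, x, indexing="ij")
    d = jnp.sqrt((X - 0.5) ** 2 + (Y - 0.5) ** 2) - radius
    mask = 0.5 * (1.0 - jnp.tanh(d / eps))   # differentiable solid indicator
    u = jnp.ones((n, n))                     # uniform toy "inflow"
    u = u / (1.0 + dt * eta * mask)          # implicit penalisation step
    return jnp.mean(u ** 2)

# Gradient of the flow objective with respect to the geometry parameter.
# A larger disk damps more of the flow, so this gradient is negative.
g = jax.grad(flow_loss)(0.2)
```

The same pattern extends to full geometries: as long as the mask and every solver step are smooth JAX operations, `jax.grad` (or `jax.vjp`) carries the flow sensitivity back to the design variables.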
Most CFD and ML pipelines still treat the solver as a black box, which makes gradient-based design difficult or impossible.
AeroJAX is an attempt to keep the physics structure intact while making the entire pipeline differentiable.