Parax: Parametric Modeling in JAX + Equinox [P]

Reddit r/MachineLearning / 4/9/2026


Key Points

  • Parax is a Python add-on built on Equinox that targets “parameter-first” modeling workflows in JAX.
  • The library introduces parax.Parameter and parax.Module (both inheriting from eqx.Module), which let users attach metadata to parameters, such as marking them as fixed or storing prior distributions.
  • It provides utilities for inspecting and manipulating parameters within deep module hierarchies while keeping Equinox’s immutable design principles.
  • The project includes documentation and examples, positioning Parax as a potentially reusable tool for scientific applications built with JAX/Equinox.

Hi everyone!

Just wanted to share my Python project Parax - an add-on on top of the Equinox library catering for parameter-first modeling in JAX.

For our scientific applications, we found that we often needed to attach metadata to our parameter objects, such as marking them as fixed or attaching a prior probability distribution. We also often needed to manipulate these parameters in very deep hierarchies, which can sometimes be unintuitive using eqx.tree_at.

We therefore developed Parax, which provides parax.Parameter and parax.Module (both of which inherit from eqx.Module), as well as a few helper utilities. These offer a more object-orientated approach to model inspection and manipulation, while still following Equinox's immutable principles.

There is some documentation along with a few examples. Perhaps the package is of use to someone else out there! :)

Cheers,
Gary

submitted by /u/gvcallen