Parax: JAX + Equinox におけるパラメトリックモデリング [P]

Reddit r/MachineLearning / 2026/4/9

💬 オピニオンDeveloper Stack & InfrastructureTools & Practical Usage

要点

  • Parax は Equinox 上に構築された Python のアドオンで、JAX における「パラメータ優先」のモデリングワークフローを対象としています。
  • このライブラリは、eqx.Module を継承する Parax.Parameter と Parax.Module を導入し、パラメータを固定としてマークしたり事前分布を保存したりするなど、パラメータにメタデータを付与するのを支援します。
  • Equinox の不変(イミュータブル)設計原則を維持しながら、深いモジュール階層内のパラメータを検査したり操作したりするためのユーティリティを提供します。
  • このプロジェクトにはドキュメントと例が含まれており、Parax を JAX/Equinox で構築される科学アプリケーションで再利用可能なツールになり得るものとして位置付けています。

こんにちは、みなさん!

Pythonプロジェクト Parax を共有したくて投稿しました。これは Equinox ライブラリの上に構築したアドオンで、JAXにおける「パラメータ優先」のモデリングに対応しています。

私たちの科学アプリケーションでは、固定としてマークしたり、事前確率分布を紐づけたりといった具合に、パラメータオブジェクトにメタデータを付与する必要があることがよくありました。さらに、非常に深い階層の中でこれらのパラメータを操作する必要も多く、状況によっては eqx.tree_at を使うと直感的でないことがあります。

そこで、Paraxを開発しました。これはparax.Parameterparax.Module(どちらも eqx.Module を継承)に加えて、いくつかの補助ユーティリティを提供します。これにより、Equinoxの不変(immutable)の原則に従いながらも、よりオブジェクト指向的なモデルの検査・操作のアプローチが可能になります。

いくつかの ドキュメント と、いくつかの例もあります。もしかすると、どなたか他の方にとって役に立つかもしれません! :)

それでは、
Gary

submitted by /u/gvcallen
[link] [comments]