diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 72be6474..73454d5e 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -9,8 +9,8 @@ jobs: name: Build and publish documentation runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 with: python-version: "3.x" diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 82ea088e..a55b31cc 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -10,9 +10,9 @@ jobs: name: Build distribution runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.x" @@ -21,7 +21,7 @@ jobs: - name: Build a binary wheel and a source tarball run: python3 -m build - name: Store the distribution packages - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: python-package-distributions path: dist/ @@ -41,7 +41,7 @@ jobs: steps: - name: Download all the dists - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: name: python-package-distributions path: dist/ diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c1eb8d08..14dd2c66 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,10 +15,10 @@ jobs: timeout-minutes: 10 steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 with: - python-version: "3.10" + python-version: "3.x" - name: Install dependencies run: | diff --git a/flowjax/experimental/numpyro.py b/flowjax/experimental/numpyro.py index d000d41a..f34983b0 100644 --- a/flowjax/experimental/numpyro.py +++ b/flowjax/experimental/numpyro.py @@ -8,7 +8,6 @@ import equinox as eqx import jax -import jax.random as jr import paramax from jaxtyping import Array, ArrayLike @@ -18,7 +17,14 @@ try: import numpyro -except ImportError as e: + from packaging.version import Version + + if Version(numpyro.__version__) < Version("0.20.0"): + raise ImportError( + f"numpyro version must be >= 0.20.0, got {numpyro.__version__}." + ) + +except ModuleNotFoundError as e: e.add_note( "Note, in order to interface with numpyro, it must be installed. Please see " "https://num.pyro.ai/en/latest/getting_started.html#installation", @@ -52,7 +58,6 @@ def log_prob(self, value, intermediates=None): y = value for i, transform in enumerate(reversed(self.transforms)): - if isinstance(transform, _BijectionToNumpyro) and intermediates is None: # Compute inv and log det in one inv_transform = _BijectionToNumpyro( @@ -182,7 +187,7 @@ def __init__( condition = arraylike_to_array(condition, "condition") self._condition = condition - self.support = _RealNdim(dist.ndim) + self._support = _RealNdim(dist.ndim) batch_shape = _get_batch_shape(condition, dist.cond_shape) super().__init__(batch_shape, dist.shape) @@ -191,10 +196,6 @@ def condition(self): return jax.lax.stop_gradient(self._condition) def sample(self, key, sample_shape=()): - # TODO remove when old-style keys fully deprecated - if not jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key): - key = jr.wrap_key_data(key) - return self.dist.sample(key, sample_shape, self.condition) def log_prob(self, value): @@ -243,8 +244,8 @@ def __init__( domain = _RealNdim(len(bijection.shape)) if codomain is None: codomain = _RealNdim(len(bijection.shape)) - self.domain = domain - self.codomain = codomain + self._domain = domain + self._codomain = codomain self._argcheck_domains() def __call__(self, x): @@ -265,6 +266,14 @@ def call_with_intermediates(self, x): def condition(self): return jax.lax.stop_gradient(self._condition) + @property + def domain(self): + return self._domain + + @property + def codomain(self): + return self._codomain + def tree_flatten(self): return (self.bijection, self._condition, self.domain, self.codomain), ( ("bijection", "_condition", "domain", "codomain"), diff --git a/pyproject.toml b/pyproject.toml index 7014924a..d4bb290c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dev = [ "sphinx-autodoc-typehints", "nbsphinx", "ipython", - "numpyro", + "numpyro>=0.20.0", ] [build-system]