Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ jobs:
name: Build and publish documentation
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- uses: actions/checkout@v6
- uses: actions/setup-python@v6
with:
python-version: "3.x"

Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ jobs:
name: Build distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: "3.x"

Expand All @@ -21,7 +21,7 @@ jobs:
- name: Build a binary wheel and a source tarball
run: python3 -m build
- name: Store the distribution packages
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v5
with:
name: python-package-distributions
path: dist/
Expand All @@ -41,7 +41,7 @@ jobs:

steps:
- name: Download all the dists
uses: actions/download-artifact@v4
uses: actions/download-artifact@v5
with:
name: python-package-distributions
path: dist/
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ jobs:
timeout-minutes: 10

steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- uses: actions/checkout@v6
- uses: actions/setup-python@v6
with:
python-version: "3.10"
python-version: "3.x"

- name: Install dependencies
run: |
Expand Down
29 changes: 19 additions & 10 deletions flowjax/experimental/numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import equinox as eqx
import jax
import jax.random as jr
import paramax
from jaxtyping import Array, ArrayLike

Expand All @@ -18,7 +17,14 @@

try:
import numpyro
except ImportError as e:
from packaging.version import Version

if Version(numpyro.__version__) < Version("0.20.0"):
raise ImportError(
f"numpyro version must be >= 0.20.0, got {numpyro.__version__}."
)

except ModuleNotFoundError as e:
e.add_note(
"Note, in order to interface with numpyro, it must be installed. Please see "
"https://num.pyro.ai/en/latest/getting_started.html#installation",
Expand Down Expand Up @@ -52,7 +58,6 @@ def log_prob(self, value, intermediates=None):
y = value

for i, transform in enumerate(reversed(self.transforms)):

if isinstance(transform, _BijectionToNumpyro) and intermediates is None:
# Compute inv and log det in one
inv_transform = _BijectionToNumpyro(
Expand Down Expand Up @@ -182,7 +187,7 @@ def __init__(
condition = arraylike_to_array(condition, "condition")

self._condition = condition
self.support = _RealNdim(dist.ndim)
self._support = _RealNdim(dist.ndim)
batch_shape = _get_batch_shape(condition, dist.cond_shape)
super().__init__(batch_shape, dist.shape)

Expand All @@ -191,10 +196,6 @@ def condition(self):
return jax.lax.stop_gradient(self._condition)

def sample(self, key, sample_shape=()):
# TODO remove when old-style keys fully deprecated
if not jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
key = jr.wrap_key_data(key)

return self.dist.sample(key, sample_shape, self.condition)

def log_prob(self, value):
Expand Down Expand Up @@ -243,8 +244,8 @@ def __init__(
domain = _RealNdim(len(bijection.shape))
if codomain is None:
codomain = _RealNdim(len(bijection.shape))
self.domain = domain
self.codomain = codomain
self._domain = domain
self._codomain = codomain
self._argcheck_domains()

def __call__(self, x):
Expand All @@ -265,6 +266,14 @@ def call_with_intermediates(self, x):
def condition(self):
return jax.lax.stop_gradient(self._condition)

@property
def domain(self):
return self._domain

@property
def codomain(self):
return self._codomain

def tree_flatten(self):
return (self.bijection, self._condition, self.domain, self.codomain), (
("bijection", "_condition", "domain", "codomain"),
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dev = [
"sphinx-autodoc-typehints",
"nbsphinx",
"ipython",
"numpyro",
"numpyro>=0.20.0",
]

[build-system]
Expand Down