From b759245a77d0b5c13427c2daa8a425dd84cd491b Mon Sep 17 00:00:00 2001 From: Daniel Ward Date: Sun, 19 Oct 2025 09:53:33 +0100 Subject: [PATCH] better_signatures --- flowjax/bijections/jax_transforms.py | 21 +++++++-------- flowjax/bijections/utils.py | 39 ++++++++++++++++------------ flowjax/distributions.py | 3 +-- 3 files changed, 32 insertions(+), 31 deletions(-) diff --git a/flowjax/bijections/jax_transforms.py b/flowjax/bijections/jax_transforms.py index d3672c3c..56d283e3 100644 --- a/flowjax/bijections/jax_transforms.py +++ b/flowjax/bijections/jax_transforms.py @@ -37,6 +37,13 @@ class Scan(AbstractBijection): """ bijection: AbstractBijection + shape: tuple[int, ...] + cond_shape: tuple[int, ...] | None + + def __init__(self, bijection: AbstractBijection): + self.bijection = bijection + self.shape = self.bijection.shape + self.cond_shape = self.bijection.cond_shape def transform_and_log_det(self, x, condition=None): def step(carry, bijection): @@ -56,14 +63,6 @@ def step(carry, bijection): (y, log_det), _ = _filter_scan(step, (y, 0), self.bijection, reverse=True) return y, log_det - @property - def shape(self): - return self.bijection.shape - - @property - def cond_shape(self): - return self.bijection.cond_shape - def _filter_scan(f, init, xs, *, reverse=False): params, static = eqx.partition(xs, filter_spec=eqx.is_array) @@ -151,6 +150,7 @@ class Vmap(AbstractBijection): bijection: AbstractBijection in_axes: tuple axis_size: int + shape: tuple[int, ...] cond_shape: tuple[int, ...] | None def __init__( @@ -173,6 +173,7 @@ def __init__( self.in_axes = (in_axes, 0, in_axes_condition) self.bijection = bijection self.axis_size = axis_size + self.shape = (self.axis_size, *self.bijection.shape) self.cond_shape = self.get_cond_shape(in_axes_condition) def vmap(self, f: Callable): @@ -192,10 +193,6 @@ def _inverse_and_log_det(bijection, x, condition): x, log_det = self.vmap(_inverse_and_log_det)(self.bijection, y, condition) return x, jnp.sum(log_det) - @property - def shape(self): - return (self.axis_size, *self.bijection.shape) - def get_cond_shape(self, cond_ax): if self.bijection.cond_shape is None or cond_ax is None: return self.bijection.cond_shape diff --git a/flowjax/bijections/utils.py b/flowjax/bijections/utils.py index a75333b5..d906a0fe 100644 --- a/flowjax/bijections/utils.py +++ b/flowjax/bijections/utils.py @@ -29,6 +29,13 @@ class Invert(AbstractBijection): """ bijection: AbstractBijection + shape: tuple[int, ...] + cond_shape: tuple[int, ...] | None + + def __init__(self, bijection: AbstractBijection): + self.bijection = bijection + self.shape = self.bijection.shape + self.cond_shape = self.bijection.cond_shape def transform_and_log_det(self, x, condition=None): return self.bijection.inverse_and_log_det(x, condition) @@ -36,14 +43,6 @@ def transform_and_log_det(self, x, condition=None): def inverse_and_log_det(self, y, condition=None): return self.bijection.transform_and_log_det(y, condition) - @property - def shape(self): - return self.bijection.shape - - @property - def cond_shape(self): - return self.bijection.cond_shape - class Permute(AbstractBijection): """Permutation transformation. @@ -114,12 +113,24 @@ class Indexed(AbstractBijection): bijection: Bijection that is compatible with the subset of x indexed by idxs. idxs: Indices (Integer, a slice, or an ndarray with integer/bool dtype) of the transformed portion. - shape: Shape of the bijection. Defaults to None. + shape: Shape of the bijection. """ bijection: AbstractBijection idxs: int | slice | Array | tuple shape: tuple[int, ...] + cond_shape: tuple[int, ...] | None + + def __init__( + self, + bijection: AbstractBijection, + idxs: int | slice | Array | tuple, + shape: tuple[int, ...], + ): + self.bijection = bijection + self.idxs = idxs + self.shape = shape + self.cond_shape = bijection.cond_shape def __check_init__(self): expected_shape = jnp.zeros(self.shape)[self.idxs].shape @@ -138,10 +149,6 @@ def inverse_and_log_det(self, y, condition=None): x, log_det = self.bijection.inverse_and_log_det(y[self.idxs], condition) return y.at[self.idxs].set(x), log_det - @property - def cond_shape(self): - return self.bijection.cond_shape - class Identity(AbstractBijection): """The identity bijection. @@ -174,6 +181,7 @@ class EmbedCondition(AbstractBijection): """ bijection: AbstractBijection + shape: tuple[int, ...] cond_shape: tuple[int, ...] embedding_net: Callable @@ -185,6 +193,7 @@ def __init__( ): self.bijection = bijection self.embedding_net = embedding_net + self.shape = self.bijection.shape self.cond_shape = raw_cond_shape def transform_and_log_det(self, x, condition=None): @@ -195,10 +204,6 @@ def inverse_and_log_det(self, y, condition=None): condition = self.embedding_net(condition) return self.bijection.inverse_and_log_det(y, condition) - @property - def shape(self): - return self.bijection.shape - class Reshape(AbstractBijection): """Wraps bijection methods with reshaping operations. diff --git a/flowjax/distributions.py b/flowjax/distributions.py index 64fbc84c..5613e496 100644 --- a/flowjax/distributions.py +++ b/flowjax/distributions.py @@ -209,7 +209,6 @@ def _get_sample_keys( sample_shape: tuple[int, ...], condition, ): - if not dtypes.issubdtype(key.dtype, dtypes.prng_key): raise TypeError("New-style typed JAX PRNG keys required.") @@ -339,6 +338,7 @@ class Transformed(AbstractTransformed): >>> bijection = Affine(1) >>> transformed = Transformed(normal, bijection) """ + base_dist: AbstractDistribution bijection: AbstractBijection @@ -348,7 +348,6 @@ def __init__(self, base_dist: AbstractDistribution, bijection: AbstractBijection self.bijection = bijection - class AbstractLocScaleDistribution(AbstractTransformed): """Abstract distribution class for affine transformed distributions."""