Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions flowjax/bijections/jax_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ class Scan(AbstractBijection):
"""

bijection: AbstractBijection
shape: tuple[int, ...]
cond_shape: tuple[int, ...] | None

def __init__(self, bijection: AbstractBijection):
self.bijection = bijection
self.shape = self.bijection.shape
self.cond_shape = self.bijection.cond_shape

def transform_and_log_det(self, x, condition=None):
def step(carry, bijection):
Expand All @@ -56,14 +63,6 @@ def step(carry, bijection):
(y, log_det), _ = _filter_scan(step, (y, 0), self.bijection, reverse=True)
return y, log_det

@property
def shape(self):
return self.bijection.shape

@property
def cond_shape(self):
return self.bijection.cond_shape


def _filter_scan(f, init, xs, *, reverse=False):
params, static = eqx.partition(xs, filter_spec=eqx.is_array)
Expand Down Expand Up @@ -151,6 +150,7 @@ class Vmap(AbstractBijection):
bijection: AbstractBijection
in_axes: tuple
axis_size: int
shape: tuple[int, ...]
cond_shape: tuple[int, ...] | None

def __init__(
Expand All @@ -173,6 +173,7 @@ def __init__(
self.in_axes = (in_axes, 0, in_axes_condition)
self.bijection = bijection
self.axis_size = axis_size
self.shape = (self.axis_size, *self.bijection.shape)
self.cond_shape = self.get_cond_shape(in_axes_condition)

def vmap(self, f: Callable):
Expand All @@ -192,10 +193,6 @@ def _inverse_and_log_det(bijection, x, condition):
x, log_det = self.vmap(_inverse_and_log_det)(self.bijection, y, condition)
return x, jnp.sum(log_det)

@property
def shape(self):
return (self.axis_size, *self.bijection.shape)

def get_cond_shape(self, cond_ax):
if self.bijection.cond_shape is None or cond_ax is None:
return self.bijection.cond_shape
Expand Down
39 changes: 22 additions & 17 deletions flowjax/bijections/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,20 @@ class Invert(AbstractBijection):
"""

bijection: AbstractBijection
shape: tuple[int, ...]
cond_shape: tuple[int, ...] | None

def __init__(self, bijection: AbstractBijection):
self.bijection = bijection
self.shape = self.bijection.shape
self.cond_shape = self.bijection.cond_shape

def transform_and_log_det(self, x, condition=None):
return self.bijection.inverse_and_log_det(x, condition)

def inverse_and_log_det(self, y, condition=None):
return self.bijection.transform_and_log_det(y, condition)

@property
def shape(self):
return self.bijection.shape

@property
def cond_shape(self):
return self.bijection.cond_shape


class Permute(AbstractBijection):
"""Permutation transformation.
Expand Down Expand Up @@ -114,12 +113,24 @@ class Indexed(AbstractBijection):
bijection: Bijection that is compatible with the subset of x indexed by idxs.
idxs: Indices (Integer, a slice, or an ndarray with integer/bool dtype) of the
transformed portion.
shape: Shape of the bijection. Defaults to None.
shape: Shape of the bijection.
"""

bijection: AbstractBijection
idxs: int | slice | Array | tuple
shape: tuple[int, ...]
cond_shape: tuple[int, ...] | None

def __init__(
self,
bijection: AbstractBijection,
idxs: int | slice | Array | tuple,
shape: tuple[int, ...],
):
self.bijection = bijection
self.idxs = idxs
self.shape = shape
self.cond_shape = bijection.cond_shape

def __check_init__(self):
expected_shape = jnp.zeros(self.shape)[self.idxs].shape
Expand All @@ -138,10 +149,6 @@ def inverse_and_log_det(self, y, condition=None):
x, log_det = self.bijection.inverse_and_log_det(y[self.idxs], condition)
return y.at[self.idxs].set(x), log_det

@property
def cond_shape(self):
return self.bijection.cond_shape


class Identity(AbstractBijection):
"""The identity bijection.
Expand Down Expand Up @@ -174,6 +181,7 @@ class EmbedCondition(AbstractBijection):
"""

bijection: AbstractBijection
shape: tuple[int, ...]
cond_shape: tuple[int, ...]
embedding_net: Callable

Expand All @@ -185,6 +193,7 @@ def __init__(
):
self.bijection = bijection
self.embedding_net = embedding_net
self.shape = self.bijection.shape
self.cond_shape = raw_cond_shape

def transform_and_log_det(self, x, condition=None):
Expand All @@ -195,10 +204,6 @@ def inverse_and_log_det(self, y, condition=None):
condition = self.embedding_net(condition)
return self.bijection.inverse_and_log_det(y, condition)

@property
def shape(self):
return self.bijection.shape


class Reshape(AbstractBijection):
"""Wraps bijection methods with reshaping operations.
Expand Down
3 changes: 1 addition & 2 deletions flowjax/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ def _get_sample_keys(
sample_shape: tuple[int, ...],
condition,
):

if not dtypes.issubdtype(key.dtype, dtypes.prng_key):
raise TypeError("New-style typed JAX PRNG keys required.")

Expand Down Expand Up @@ -339,6 +338,7 @@ class Transformed(AbstractTransformed):
>>> bijection = Affine(1)
>>> transformed = Transformed(normal, bijection)
"""

base_dist: AbstractDistribution
bijection: AbstractBijection

Expand All @@ -348,7 +348,6 @@ def __init__(self, base_dist: AbstractDistribution, bijection: AbstractBijection
self.bijection = bijection



class AbstractLocScaleDistribution(AbstractTransformed):
"""Abstract distribution class for affine transformed distributions."""

Expand Down