
Releases: danielward27/flowjax

v18.0.0

01 Feb 10:49
01ef1b5


What's Changed

Breaking changes

Copied from #231 for visibility.
Introduces some breaking changes to tidy up RationalQuadraticSpline:

  • Uses RealToIncreasingOnInterval from paramax, which is simpler than the approach previously used.

  • When constructing a rational quadratic spline, the softmax_adjust: float | int = 1e-2 argument has been removed.

  • Instead, minimum widths of the bins of the spline are controlled by a new min_width argument.

  • As a result of the changed parameterization, numerical results may differ slightly from previous versions.
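As an illustration of the idea behind a min-width parameterization (this is a hypothetical sketch in plain Python, not flowjax's or paramax's actual implementation; the `bin_widths` helper and its signature are invented for this example), the unconstrained parameters can be mapped through a softmax and each bin given a guaranteed floor:

```python
# Illustrative sketch of a min_width-style bin parameterization.
# NOT flowjax's actual code; `bin_widths` is a hypothetical helper.
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def bin_widths(params, interval=1.0, min_width=1e-2):
    """Map unconstrained params to positive widths summing to `interval`,
    with every width at least `min_width * interval`."""
    n = len(params)
    free = interval * (1 - n * min_width)  # mass left after the floors
    return [interval * min_width + free * p for p in softmax(params)]

widths = bin_widths([0.3, -1.2, 2.0], interval=1.0, min_width=1e-2)
```

With this construction the widths always sum to the interval length, and no bin can collapse below the floor, which is the property a `min_width` argument controls directly.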

Smaller breaking changes:

  • Removed the RationalQuadraticSpline.derivative method and the RationalQuadraticSpline.min_deriv attribute.

Full Changelog: v17.2.1...v18.0.0

v17.2.1

19 Oct 18:24
3822de2


What's Changed

Full Changelog: v17.2.0...v17.2.1

v17.2.0

06 Jul 14:35
a39bd83


This release allows passing an arbitrary number of arrays to fit_to_data. One motivation for this is to allow losses that take more arrays (or different arrays than just x and condition) as inputs. For example, weighted training can now be performed quite simply, with a pattern like:

from paramax import unwrap
import equinox as eqx
from flowjax.train import fit_to_data

def weighted_maximum_likelihood(params, static, x, weights, key=None):
    # Recombine trainable and static leaves, then unwrap parameterizations.
    dist = unwrap(eqx.combine(params, static))
    # Weighted negative log-likelihood, normalized by the number of samples.
    return -(weights * dist.log_prob(x)).sum() / len(weights)

flow, losses = fit_to_data(
    train_key,
    flow,
    (x, weights),  # arbitrary arrays are passed through to the loss
    learning_rate=5e-3,
    max_epochs=200,
    loss_fn=weighted_maximum_likelihood,
)

This release also adds a deprecation warning for passing x and condition as keyword arguments to fit_to_data.

What's Changed

Full Changelog: v17.1.2...v17.2.0

v17.1.2

07 May 13:02
9686a70


What's Changed

Full Changelog: v17.1.1...v17.1.2

v17.1.1

05 Mar 09:38
5957107


Release to (hopefully) fix #214.

What's Changed

Full Changelog: v17.1.0...v17.1.1

v17.1.0

19 Feb 16:39
262d7a9


No breaking changes. Technically, the structure of the triangular_spline_flow PyTree has changed (from the update to using Sandwich), which could be considered a breaking change, but I'm treating this as minor enough not to warrant a major version bump. The log determinant calculation was wrong for NumericalInverse; this release fixes that bug, in addition to adding some bijections.

What's Changed

New Contributors

Full Changelog: v17.0.2...v17.1.0

v17.0.2

19 Dec 09:50
5bcdc93


Updates to the workflow packages led to some errors in the release to PyPI, which this release fixes. Below are the changes from v16 to v17, which are of more interest to users.

Breaking Changes

  • The AbstractBijection class now implements transform and inverse methods by indexing the outputs of transform_and_log_det and inverse_and_log_det. AFAIK, under JIT, dead code elimination prevents computation of log_det where unnecessary, minimizing the benefits of directly implementing these methods. Custom bijections should generally avoid overriding transform and inverse.
  • The flowjax.wrappers module has been removed. Its functionality is now available in a new, separate package: paramax. Most functionality is unchanged when imported from paramax.
  • Partial has been renamed to Indexed, as the former was likely to be confused with functools.partial or jax.tree_util.Partial.
  • The deprecated fit_to_variational_target function has been removed. Use fit_to_key_based_loss instead.
  • Numerical inverse methods are now provided using the NumericalInverse composition. This means:
    • Users directly calling inverse methods on BlockAutoregressiveNetwork will encounter an error. To resolve this, supply an
      inverse method via NumericalInverse (as is done in block_neural_autoregressive_flow).
    • Users accessing the attributes of block_neural_autoregressive_flows might also face a breaking change; an additional .bijection attribute access may be required to extract BlockAutoregressiveNetwork from NumericalInverse.
  • The uniform distribution is now non-trainable by default. Optimizing a uniform distribution was most commonly a mistake, e.g. leading to violated support assumptions.
  • The root-finding algorithms in flowjax.bisection_search have been moved to flowjax.root_finding. Check the updated function names and documentation in the new module if you directly use these methods.
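The first bullet's pattern (default transform and inverse methods that simply index the outputs of the *_and_log_det methods) can be sketched with a toy standalone class. This is a hypothetical illustration, not flowjax's AbstractBijection; the class name and signatures here are invented for the example:

```python
# Toy sketch of the "index the *_and_log_det output" default-method pattern.
# NOT flowjax's AbstractBijection; AffineBijection is hypothetical.
import math

class AffineBijection:
    def __init__(self, scale, loc):
        self.scale, self.loc = scale, loc

    def transform_and_log_det(self, x):
        return self.scale * x + self.loc, math.log(abs(self.scale))

    def inverse_and_log_det(self, y):
        return (y - self.loc) / self.scale, -math.log(abs(self.scale))

    # Default-style methods: take element [0]. Under JIT, dead code
    # elimination can drop the unused log-det computation.
    def transform(self, x):
        return self.transform_and_log_det(x)[0]

    def inverse(self, y):
        return self.inverse_and_log_det(y)[0]

bij = AffineBijection(scale=2.0, loc=1.0)
y = bij.transform(3.0)  # -> 7.0
x = bij.inverse(y)      # -> 3.0
```

Because the default methods are thin wrappers, custom bijections only need to implement transform_and_log_det and inverse_and_log_det, and should generally not override transform and inverse.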

Apologies for any inconvenience caused by these breaking changes. If you encounter issues or have questions, please feel free to open an issue.


What's Changed

Full Changelog: v16.0.0...v17.0.0

v16.0.0

17 Oct 10:59
d711b18


What's Changed

Breaking changes:

  • Calls to get_ravelled_pytree_constructor will now need to explicitly pass the *args and **kwargs for partitioning parameters (usually setting is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable)).
  • fit_to_data now returns a list of floats, rather than a list of scalar arrays.

Note, fit_to_variational_target will be deprecated in the next version. This version adds its replacement, fit_to_key_based_loss. This change was primarily motivated by some bad defaults, e.g. steps: int = 100 and return_best=True (see #188 for details). The new name is also more general: the function can be used to fit any PyTree, and doesn't have to be used with a variational inference loss function.

Full Changelog: v15.1.0...v16.0.0

v15.1.0

07 Oct 12:17


What's Changed

Full Changelog: v15.0.0...v15.1.0

v15.0.0

17 Sep 20:00
73163c9


Breaking changes:

  • Users must switch from old-style PRNGKey arrays to new-style ones (replacing jax.random.PRNGKey with jax.random.key). The old keys will be deprecated in JAX.
  • LogNormal is now an Exp-transformed Normal. This is an implementation detail, unless you previously relied on log_normal.base_dist or log_normal.bijection.
  • recursive_unwrap was removed, as it did not provide additional functionality on its own.

What's Changed

Full Changelog: v14.0.0...v15.0.0