From 836c87bf81e1d2181a20111715e74e5be3cec313 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Thu, 13 Feb 2025 18:04:37 +0100
Subject: [PATCH 1/4] Allow datasets with several arrays in fit_to_data

---
 flowjax/train/loops.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/flowjax/train/loops.py b/flowjax/train/loops.py
index bffb8bcf..7512a5a5 100644
--- a/flowjax/train/loops.py
+++ b/flowjax/train/loops.py
@@ -78,7 +78,7 @@ def fit_to_key_based_loss(
 def fit_to_data(
     key: PRNGKeyArray,
     dist: PyTree,  # Custom losses may support broader types than AbstractDistribution
-    x: ArrayLike,
+    x: ArrayLike | tuple[ArrayLike, ...],
     *,
     condition: ArrayLike | None = None,
     loss_fn: Callable | None = None,
@@ -120,7 +120,12 @@ def fit_to_data(
     Returns:
         A tuple containing the trained distribution and the losses.
     """
-    data = (x,) if condition is None else (x, condition)
+    if isinstance(x, tuple):
+        data = x
+    else:
+        data = (x,)
+    if condition is not None:
+        data = (*data, condition)
     data = tuple(jnp.asarray(a) for a in data)

     if loss_fn is None:
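A minimal usage sketch of the interface this patch enables: `x` may be a tuple of arrays that are batched together along their first axis. The flow setup mirrors the project's tests; `weights` and the `weighted_nll` loss are hypothetical stand-ins for whatever consumes the extra array, and the loss signature assumes the training loop passes one batch of each array positionally, as in the tests added later in this series.

    import equinox as eqx
    import jax.numpy as jnp
    from jax import random
    from paramax.wrappers import unwrap

    from flowjax.bijections import Affine
    from flowjax.distributions import Normal, Transformed
    from flowjax.train import fit_to_data

    dim = 3
    base_dist = Normal(jnp.ones(dim), jnp.ones(dim))
    flow = Transformed(base_dist, Affine(jnp.ones(dim), jnp.ones(dim)))
    values = random.normal(random.key(0), (100, dim))
    weights = jnp.ones(100)  # hypothetical per-sample weights, batched alongside `values`

    def weighted_nll(params, static, values, weights, key=None):
        # Hypothetical loss: weighted negative log likelihood over one batch of each array.
        flow = unwrap(eqx.combine(params, static, is_leaf=eqx.is_inexact_array))
        return -(weights * flow.log_prob(values)).mean()

    flow, losses = fit_to_data(
        random.key(1),
        dist=flow,
        x=(values, weights),  # a tuple of arrays is now accepted
        loss_fn=weighted_nll,
        max_epochs=1,
        batch_size=50,
    )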
""" if isinstance(x, tuple): data = x @@ -140,7 +147,8 @@ def fit_to_data( is_leaf=lambda leaf: isinstance(leaf, paramax.NonTrainable), ) best_params = params - opt_state = optimizer.init(params) + if opt_state is None: + opt_state = optimizer.init(params) # train val split key, subkey = jr.split(key) @@ -189,4 +197,8 @@ def fit_to_data( params = best_params if return_best else params dist = eqx.combine(params, static) + + if return_opt_state: + return dist, losses, opt_state + return dist, losses diff --git a/tests/test_train/test_data_fit.py b/tests/test_train/test_data_fit.py index 553f341a..06d4bba3 100644 --- a/tests/test_train/test_data_fit.py +++ b/tests/test_train/test_data_fit.py @@ -1,6 +1,8 @@ import equinox as eqx import jax.numpy as jnp from jax import random +from paramax.wrappers import unwrap +import optax from flowjax.bijections import Affine from flowjax.distributions import Normal, Transformed @@ -29,3 +31,54 @@ def test_data_fit(): assert jnp.all(before.bijection.loc != after.bijection.loc) assert isinstance(losses["train"][0], float) assert isinstance(losses["val"][0], float) + + +def test_data_fit_opt_state(): + dim = 3 + mean, std = jnp.ones(dim), jnp.ones(dim) + base_dist = Normal(mean, std) + flow = Transformed(base_dist, Affine(jnp.ones(dim), jnp.ones(dim))) + + # All params should change by default + before = eqx.filter(flow, eqx.is_inexact_array) + values = random.normal(random.key(0), (100, dim)) + log_probs = random.normal(random.key(1), (100,)) + + def loss_fn(params, static, values, log_probs, key=None): + flow = unwrap(eqx.combine(params, static, is_leaf=eqx.is_inexact_array)) + return (log_probs - flow.log_prob(params, values)).mean() + + flow, losses, opt_state = fit_to_data( + random.key(0), + dist=flow, + x=(values, log_probs), + max_epochs=1, + batch_size=50, + return_opt_state=True, + ) + after = eqx.filter(flow, eqx.is_inexact_array) + + assert jnp.all(before.base_dist.bijection.loc != after.base_dist.bijection.loc) + assert jnp.all(before.bijection.loc != after.bijection.loc) + assert isinstance(losses["train"][0], float) + assert isinstance(losses["val"][0], float) + + # Continue training on new data + values = random.normal(random.key(2), (100, dim)) + log_probs = random.normal(random.key(3), (100,)) + + flow, losses, opt_state = fit_to_data( + random.key(4), + dist=flow, + x=(values, log_probs), + max_epochs=1, + batch_size=50, + return_opt_state=True, + opt_state=opt_state, + ) + after = eqx.filter(flow, eqx.is_inexact_array) + + assert jnp.all(before.base_dist.bijection.loc != after.base_dist.bijection.loc) + assert jnp.all(before.bijection.loc != after.bijection.loc) + assert isinstance(losses["train"][0], float) + assert isinstance(losses["val"][0], float) From ff3d30e5f0c5845741166f3885f52667989c29cb Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 18 Feb 2025 21:12:27 +0100 Subject: [PATCH 3/4] Allow multiple positional data arrays in fit_to_data --- flowjax/train/loops.py | 29 ++++++++++++++++++++--------- tests/test_train/test_data_fit.py | 23 +++++++++++++++++------ 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/flowjax/train/loops.py b/flowjax/train/loops.py index 68f06834..7fb2498b 100644 --- a/flowjax/train/loops.py +++ b/flowjax/train/loops.py @@ -1,6 +1,7 @@ """Training loops.""" from collections.abc import Callable +from warnings import warn import equinox as eqx import jax.numpy as jnp @@ -28,6 +29,8 @@ def fit_to_key_based_loss( learning_rate: float = 5e-4, optimizer: optax.GradientTransformation | None 
From ff3d30e5f0c5845741166f3885f52667989c29cb Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Tue, 18 Feb 2025 21:12:27 +0100
Subject: [PATCH 3/4] Allow multiple positional data arrays in fit_to_data

---
 flowjax/train/loops.py            | 29 ++++++++++++++++---------
 tests/test_train/test_data_fit.py | 23 +++++++++++++++++------
 2 files changed, 37 insertions(+), 15 deletions(-)

diff --git a/flowjax/train/loops.py b/flowjax/train/loops.py
index 68f06834..7fb2498b 100644
--- a/flowjax/train/loops.py
+++ b/flowjax/train/loops.py
@@ -1,6 +1,7 @@
 """Training loops."""

 from collections.abc import Callable
+from warnings import warn

 import equinox as eqx
 import jax.numpy as jnp
@@ -28,6 +29,8 @@ def fit_to_key_based_loss(
     learning_rate: float = 5e-4,
     optimizer: optax.GradientTransformation | None = None,
     show_progress: bool = True,
+    opt_state: optax.OptState | None = None,
+    return_opt_state: bool = False,
 ):
     """Train a pytree, using a loss with params, static and key as arguments.

@@ -43,6 +46,8 @@ def fit_to_key_based_loss(
         learning_rate: The adam learning rate. Ignored if optimizer is provided.
         optimizer: Optax optimizer. Defaults to None.
         show_progress: Whether to show progress bar. Defaults to True.
+        opt_state: Optional initial state of the optimizer.
+        return_opt_state: Whether to return the optimizer state.

     Returns:
         A tuple containing the trained pytree and the losses.
@@ -55,7 +60,8 @@ def fit_to_key_based_loss(
         eqx.is_inexact_array,
         is_leaf=lambda leaf: isinstance(leaf, paramax.NonTrainable),
     )
-    opt_state = optimizer.init(params)
+    if opt_state is None:
+        opt_state = optimizer.init(params)

     losses = []

@@ -72,14 +78,15 @@ def fit_to_key_based_loss(
         )
         losses.append(loss.item())
         keys.set_postfix({"loss": loss.item()})
+    if return_opt_state:
+        return eqx.combine(params, static), losses, opt_state
     return eqx.combine(params, static), losses


 def fit_to_data(
     key: PRNGKeyArray,
     dist: PyTree,  # Custom losses may support broader types than AbstractDistribution
-    x: ArrayLike | tuple[ArrayLike, ...],
-    *,
+    *data: ArrayLike,
     condition: ArrayLike | None = None,
     loss_fn: Callable | None = None,
     learning_rate: float = 5e-4,
@@ -103,11 +110,14 @@ def fit_to_data(
     Args:
         key: Jax random seed.
        dist: The pytree to train (usually a distribution).
-        x: Samples from target distribution.
+        data: Samples from target distribution. If several arrays are passed, each one
+            is split into batches along the first axes, and one batch of each is
+            passed into the loss function.
         learning_rate: The learning rate for adam optimizer. Ignored if optimizer is
             provided.
         optimizer: Optax optimizer. Defaults to None.
-        condition: Conditioning variables. Defaults to None.
+        condition: Conditioning variables. Defaults to None. This argument is
+            deprecated; pass conditioning data as an extra positional array instead.
         loss_fn: Loss function. Defaults to MaximumLikelihoodLoss.
         max_epochs: Maximum number of epochs. Defaults to 100.
         max_patience: Number of consecutive epochs with no validation loss improvement
@@ -127,11 +137,12 @@ def fit_to_data(

         If ``return_opt_state`` is True, the final optimizer state is also returned.
     """
-    if isinstance(x, tuple):
-        data = x
-    else:
-        data = (x,)
     if condition is not None:
+        warn(
+            "The `condition` argument is deprecated. "
+            "You can pass condition data as additional data arrays.",
+            DeprecationWarning,
+        )
         data = (*data, condition)
     data = tuple(jnp.asarray(a) for a in data)

diff --git a/tests/test_train/test_data_fit.py b/tests/test_train/test_data_fit.py
index 06d4bba3..84810f0c 100644
--- a/tests/test_train/test_data_fit.py
+++ b/tests/test_train/test_data_fit.py
@@ -20,13 +20,22 @@ def test_data_fit():
     x = random.normal(random.key(0), (100, dim))
     flow, losses = fit_to_data(
         random.key(0),
-        dist=flow,
-        x=x,
+        flow,
+        x,
         max_epochs=1,
         batch_size=50,
     )
     after = eqx.filter(flow, eqx.is_inexact_array)

+    flow2, losses2, opt_state = fit_to_data(
+        random.key(0),
+        flow,
+        x,
+        max_epochs=1,
+        batch_size=50,
+        return_opt_state=True,
+    )
+
     assert jnp.all(before.base_dist.bijection.loc != after.base_dist.bijection.loc)
     assert jnp.all(before.bijection.loc != after.bijection.loc)
     assert isinstance(losses["train"][0], float)
@@ -50,8 +59,9 @@ def loss_fn(params, static, values, log_probs, key=None):

     flow, losses, opt_state = fit_to_data(
         random.key(0),
-        dist=flow,
-        x=(values, log_probs),
+        flow,
+        values,
+        log_probs,
         max_epochs=1,
         batch_size=50,
         return_opt_state=True,
@@ -69,8 +79,9 @@ def loss_fn(params, static, values, log_probs, key=None):

     flow, losses, opt_state = fit_to_data(
         random.key(4),
-        dist=flow,
-        x=(values, log_probs),
+        flow,
+        values,
+        log_probs,
         max_epochs=1,
         batch_size=50,
         return_opt_state=True,
" + "You can pass condition data as additonal data arrays.", + DeprecationWarning, + ) data = (*data, condition) data = tuple(jnp.asarray(a) for a in data) diff --git a/tests/test_train/test_data_fit.py b/tests/test_train/test_data_fit.py index 06d4bba3..84810f0c 100644 --- a/tests/test_train/test_data_fit.py +++ b/tests/test_train/test_data_fit.py @@ -20,13 +20,22 @@ def test_data_fit(): x = random.normal(random.key(0), (100, dim)) flow, losses = fit_to_data( random.key(0), - dist=flow, - x=x, + flow, + x, max_epochs=1, batch_size=50, ) after = eqx.filter(flow, eqx.is_inexact_array) + flow2, losses2, opt_state = fit_to_data( + random.key(0), + flow, + x, + max_epochs=1, + batch_size=50, + return_opt_state=True, + ) + assert jnp.all(before.base_dist.bijection.loc != after.base_dist.bijection.loc) assert jnp.all(before.bijection.loc != after.bijection.loc) assert isinstance(losses["train"][0], float) @@ -50,8 +59,9 @@ def loss_fn(params, static, values, log_probs, key=None): flow, losses, opt_state = fit_to_data( random.key(0), - dist=flow, - x=(values, log_probs), + flow, + values, + log_probs, max_epochs=1, batch_size=50, return_opt_state=True, @@ -69,8 +79,9 @@ def loss_fn(params, static, values, log_probs, key=None): flow, losses, opt_state = fit_to_data( random.key(4), - dist=flow, - x=(values, log_probs), + flow, + values, + log_probs, max_epochs=1, batch_size=50, return_opt_state=True, From b4bb43f7f6abfffa437cf35c4e734dff9d5780d2 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 25 Feb 2025 12:18:42 +0100 Subject: [PATCH 4/4] Repair old api for fit_to_data --- flowjax/train/loops.py | 8 +++++--- tests/test_train/test_data_fit.py | 8 ++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/flowjax/train/loops.py b/flowjax/train/loops.py index 7fb2498b..fb593faf 100644 --- a/flowjax/train/loops.py +++ b/flowjax/train/loops.py @@ -86,6 +86,7 @@ def fit_to_key_based_loss( def fit_to_data( key: PRNGKeyArray, dist: PyTree, # Custom losses may support broader types than AbstractDistribution + x: ArrayLike, *data: ArrayLike, condition: ArrayLike | None = None, loss_fn: Callable | None = None, @@ -110,9 +111,9 @@ def fit_to_data( Args: key: Jax random seed. dist: The pytree to train (usually a distribution). - data: Samples from target distribution. If several arrays are passed, each one - is split into batches along the first axes, and one batch of each is - passed into the loss function. + x: Samples from target distribution. + data: Extra arrays that are sliced into batches like the samples, and passed + to the loss function. learning_rate: The learning rate for adam optimizer. Ignored if optimizer is provided. optimizer: Optax optimizer. Defaults to None. @@ -144,6 +145,7 @@ def fit_to_data( DeprecationWarning, ) data = (*data, condition) + data = (x, *data) data = tuple(jnp.asarray(a) for a in data) if loss_fn is None: diff --git a/tests/test_train/test_data_fit.py b/tests/test_train/test_data_fit.py index 84810f0c..d2017e5a 100644 --- a/tests/test_train/test_data_fit.py +++ b/tests/test_train/test_data_fit.py @@ -20,8 +20,8 @@ def test_data_fit(): x = random.normal(random.key(0), (100, dim)) flow, losses = fit_to_data( random.key(0), - flow, - x, + dist=flow, + x=x, max_epochs=1, batch_size=50, ) @@ -29,8 +29,8 @@ def test_data_fit(): flow2, losses2, opt_state = fit_to_data( random.key(0), - flow, - x, + dist=flow, + x=x, max_epochs=1, batch_size=50, return_opt_state=True,