6 changes: 3 additions & 3 deletions README.md
@@ -33,9 +33,9 @@ flow = block_neural_autoregressive_flow(
)

flow, losses = fit_to_data(
-train_key,
-flow,
-x,
+key=train_key,
+dist=flow,
+data=x,
learning_rate=5e-3,
max_epochs=200,
)
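For reference, the README example above in full reads roughly as follows after this change. This is a sketch: the imports, toy data, and flow construction are assumptions taken from the wider flowjax README, not from this diff; only the keyword-style `fit_to_data` call comes from the change itself.

```python
import jax.numpy as jnp
from jax import random

from flowjax.distributions import Normal
from flowjax.flows import block_neural_autoregressive_flow
from flowjax.train import fit_to_data

# Toy two-dimensional data for unconditional density estimation.
data_key, flow_key, train_key = random.split(random.key(0), 3)
x = random.uniform(data_key, (5000, 2))

flow = block_neural_autoregressive_flow(
    key=flow_key,
    base_dist=Normal(jnp.zeros(x.shape[1])),
)

# New keyword-style call from the diff above; data may be a single array
# or a tuple of arrays.
flow, losses = fit_to_data(
    key=train_key,
    dist=flow,
    data=x,
    learning_rate=5e-3,
    max_epochs=200,
)
```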
13 changes: 6 additions & 7 deletions docs/examples/conditional.ipynb
@@ -26,7 +26,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -49,7 +49,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -68,14 +68,14 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 34%|███▍ | 34/100 [00:09<00:19, 3.45it/s, train=1.35, val=1.33 (Max patience reached)]\n"
" 34%|███▍ | 34/100 [00:12<00:24, 2.70it/s, train=1.35, val=1.33 (Max patience reached)]\n"
]
}
],
@@ -91,8 +91,7 @@
"flow, losses = fit_to_data(\n",
" subkey,\n",
" flow,\n",
" x,\n",
" u,\n",
" data=(x, u),\n",
" learning_rate=5e-2,\n",
" max_patience=10,\n",
")"
@@ -108,7 +107,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 11,
"metadata": {},
"outputs": [
{
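The net effect for this notebook: the target x and the conditioning variables u are now passed together as a tuple through `data`, while the key and distribution can still be given positionally. A minimal sketch, assuming `subkey`, `flow`, `x`, and `u` are defined in the earlier notebook cells (not shown in this diff):

```python
flow, losses = fit_to_data(
    subkey,        # key and dist may still be passed positionally
    flow,
    data=(x, u),   # (target, condition) tuple replaces the separate x, u arguments
    learning_rate=5e-2,
    max_patience=10,
)

# For unconditional fitting a bare array is accepted; judging by the
# normalisation line added in loops.py below, data=x and data=(x,) behave
# the same.
```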
26 changes: 13 additions & 13 deletions docs/examples/constrained.ipynb

Large diffs are not rendered by default.

36 changes: 19 additions & 17 deletions docs/examples/snpe.ipynb

Large diffs are not rendered by default.

34 changes: 14 additions & 20 deletions flowjax/train/loops.py
@@ -76,12 +76,11 @@ def fit_to_key_based_loss(
return eqx.combine(params, static), losses




def fit_to_data(
key: PRNGKeyArray,
dist: PyTree, # Custom losses may support broader types than AbstractDistribution
-*data: ArrayLike,
+data: ArrayLike | tuple[ArrayLike, ...] = (),
+*,
loss_fn: Callable | None = None,
learning_rate: float = 5e-4,
optimizer: optax.GradientTransformation | None = None,
@@ -104,11 +103,10 @@ def fit_to_data(
Args:
key: Jax random seed.
dist: The pytree to train (usually a distribution).
-*data: A variable number of data arrays with matching shape on axis 0. Batches
-of each array are passed to the loss function as positional arguments
-(see documentation for ``loss_fn``). Commonly this is a single array for
-unconditional density estimation, or two arrays ``target, condition)``
-for conditional density estimation.
+data: An array or tuple of arrays passed as positional arguments to the
+loss function (see documentation for ``loss_fn``). Commonly this is a
+single array for unconditional density estimation, or two arrays
+``(target, condition)`` for conditional density estimation.
learning_rate: The learning rate for adam optimizer. Ignored if optimizer is
provided.
optimizer: Optax optimizer. Defaults to None.
@@ -123,27 +121,23 @@
was reached (when True), or the parameters after the last update (when
False). Defaults to True.
show_progress: Whether to show progress bar. Defaults to True.
-x: Deprecated. Pass in as positional argument instead. See variable argument
-*data.
-condition: Deprecated way to pass conditioning variables. Pass as a positional
-argument instead. See variable argument *data.
+x: Deprecated. Pass in data instead.
+condition: Deprecated. Pass in data instead.

Returns:
A tuple containing the trained distribution and the losses.
"""
+data = (data,) if isinstance(data, ArrayLike) else data

def _handle_deprecation(data, x, condition):
# TODO This function handles the deprecation of x and condition, so will
-# be removed when deprecated.
-
-if data != () and x is not None: # Note x passed as key word in this case
-raise ValueError("Use data argument only (pass x in data).")
-
+# be removed when deprecated. The default to tuple for data should also be
+# removed.
if x is not None or condition is not None:
warnings.warn(
"Passing x and condition as key word arguments is deprecated and will "
"be removed in the next major version. Pass both x and condition as "
"positional arguments. See documentation of *data. This change allows "
"Keyword arguments x and condition are deprecated and will "
"be removed in the next major version. Pass both x and condition "
"to the data argument. See documentation of data. This change allows "
"for more flexibility in the number of arrays required by a loss.",
FutureWarning,
)
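The closing line of the new warning message ("more flexibility in the number of arrays required by a loss") is the motivation for bundling everything into `data`. Below is a hedged sketch of a custom loss that consumes three arrays; the calling convention (params, static, then one batch per array in `data`) is an assumption based on the docstring above, not something this diff spells out.

```python
import equinox as eqx
import jax.numpy as jnp


def weighted_nll_loss(params, static, x, condition, weights, key=None):
    """Hypothetical loss taking three data arrays: targets, conditions, weights.

    Assuming fit_to_data forwards a batch of each array in ``data`` to the loss
    as positional arguments, this could be used as
    ``fit_to_data(key, dist, data=(x, u, w), loss_fn=weighted_nll_loss)``.
    """
    dist = eqx.combine(params, static)  # recombine trainable and static leaves
    return -jnp.mean(weights * dist.log_prob(x, condition))
```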
6 changes: 3 additions & 3 deletions tests/test_train/test_loops.py
@@ -20,9 +20,9 @@ def test_data_fit():
before = eqx.filter(flow, eqx.is_inexact_array)
x = random.normal(random.key(0), (100, dim))
flow, losses = fit_to_data(
-random.key(0),
-flow,
-x,
+key=random.key(0),
+dist=flow,
+data=x,
max_epochs=1,
batch_size=50,
)
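A hypothetical companion test (not part of this diff) could also pin down the deprecation path, assuming the deprecated keywords are still folded into `data` by the shim in loops.py; the flow construction below is likewise an assumption, since test_data_fit's own setup falls outside the shown hunk.

```python
import jax.numpy as jnp
import pytest
from jax import random

from flowjax.distributions import Normal
from flowjax.flows import masked_autoregressive_flow
from flowjax.train import fit_to_data


def test_deprecated_x_keyword_warns():
    # Hypothetical sketch: the deprecated `x` keyword should still train the
    # flow but emit a FutureWarning (per the shim shown in loops.py above).
    dim = 3
    flow = masked_autoregressive_flow(random.key(0), base_dist=Normal(jnp.zeros(dim)))
    x = random.normal(random.key(1), (100, dim))
    with pytest.warns(FutureWarning):
        fit_to_data(
            key=random.key(0),
            dist=flow,
            x=x,  # deprecated keyword, expected to be folded into `data`
            max_epochs=1,
            batch_size=50,
        )
```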