From 153676bde9412b375cdfdd6dbe9d051b215035bd Mon Sep 17 00:00:00 2001 From: Daniel Ward Date: Wed, 21 Jan 2026 21:29:45 +0000 Subject: [PATCH 1/3] Add RealToIncreasingOnInterval --- .gitignore | 2 +- paramax/__init__.py | 1 + paramax/wrappers.py | 70 ++++++++++++++++++++++++++++++++++++++++-- tests/test_wrappers.py | 53 ++++++++++++++++++++++++++++++-- 4 files changed, 120 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index ffb652e..5fd84d6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ *.py[cod] .pylintrc - +uv.lock # C extensions *.so diff --git a/paramax/__init__.py b/paramax/__init__.py index d3e68c9..4c58f2a 100644 --- a/paramax/__init__.py +++ b/paramax/__init__.py @@ -6,6 +6,7 @@ AbstractUnwrappable, NonTrainable, Parameterize, + RealToIncreasingOnInterval, WeightNormalization, contains_unwrappables, non_trainable, diff --git a/paramax/wrappers.py b/paramax/wrappers.py index 6b6d7d1..50c6965 100644 --- a/paramax/wrappers.py +++ b/paramax/wrappers.py @@ -6,13 +6,14 @@ from abc import abstractmethod from collections.abc import Callable -from typing import Any, Generic, TypeVar +from functools import partial +from typing import Any, Generic, Literal, TypeVar import equinox as eqx import jax import jax.numpy as jnp from jax import lax -from jax.nn import softplus +from jax.nn import softmax, softplus from jax.tree_util import tree_leaves from jaxtyping import Array, PyTree @@ -159,6 +160,71 @@ def unwrap(self) -> T: return eqx.combine(lax.stop_gradient(differentiable), static) +class RealToIncreasingOnInterval(AbstractUnwrappable[Array]): + """Unconstrained vector to increasing on a fixed interval. + + Unconstrained vector is passed into softmax to obtain widths, + which are cumulatively summed and scaled to fit the interval + (or the remainder not used by min_width). + + Note an array of size d parameterizes widths, so maps to an array + of size d+1 if both endpoints are included, d if one is included, + and d-1 if neither are included. + """ + + arr: Array + interval: tuple[float | int, float | int] + min_width: float + include_ends: Literal["both", "neither", "lower", "upper"] + + def __init__( + self, + arr: Array, + interval: tuple[float | int, float | int], + *, + min_width: float, + include_ends: Literal["both", "neither", "lower", "upper"], + ): + scale = interval[1] - interval[0] + n_widths = arr.shape[-1] + + if min_width <= 0: + raise ValueError("min_width must be greater than or equal to 0.") + + if interval[1] <= interval[0]: + raise ValueError("interval[1] must be greater than interval[0]") + + if min_width * n_widths > scale: + raise ValueError( + "min_width*n_widths is greater than the interval width, so cannot be " + "satisfied." + ) + self.arr = arr + self.interval = interval + self.min_width = min_width + self.include_ends = include_ends + + def unwrap(self) -> Array: + scale = self.interval[1] - self.interval[0] + + @partial(jnp.vectorize, signature="(a)->(b)") + def _unwrap(arr): + n_widths = self.arr.shape[-1] + widths = softmax(arr) + free = scale - self.min_width * n_widths + widths = self.min_width + free * widths + lower, upper = jnp.array([self.interval[0]]), jnp.array([self.interval[1]]) + return jnp.concatenate( + [ + *([lower] if self.include_ends in ("lower", "both") else []), + jnp.cumsum(widths, axis=-1)[:-1] + lower, + *([upper] if self.include_ends in ("upper", "both") else []), + ] + ) + + return _unwrap(self.arr) + + class WeightNormalization(AbstractUnwrappable[Array]): """Applies weight normalization (https://arxiv.org/abs/1602.07868). diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 5e88eaa..0252d51 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -1,3 +1,4 @@ +# $$ from math import prod import equinox as eqx @@ -9,11 +10,55 @@ from paramax.wrappers import ( NonTrainable, Parameterize, + RealToIncreasingOnInterval, WeightNormalization, non_trainable, unwrap, ) +# %% + + +to_inverval_test_cases = [ + { + "arr": jnp.zeros(2), + "include_ends": "both", + "expected": jnp.array([-1, 0.5, 2]), + }, + { + "arr": jnp.zeros(2), + "include_ends": "lower", + "expected": jnp.array([-1, 0.5]), + }, + { + "arr": jnp.zeros(2), + "include_ends": "upper", + "expected": jnp.array([0.5, 2]), + }, + { + "arr": jnp.zeros(2), + "include_ends": "neither", + "expected": jnp.array([0.5]), + }, + { + "arr": jnp.array([-100, 0]), + "include_ends": "both", + "expected": jnp.array([-1, -0.9, 2]), + }, +] + + +@pytest.mark.parametrize("case", to_inverval_test_cases) +def test_RealToIncreasingOnInterval(case): + real_to_inc = RealToIncreasingOnInterval( + case["arr"], + (-1, 2), + min_width=0.1, + include_ends=case["include_ends"], + ) + result = unwrap(real_to_inc) + assert pytest.approx(case["expected"]) == result + def test_Parameterize(): diag = Parameterize(jnp.diag, jnp.ones(3)) @@ -29,7 +74,6 @@ def test_nested_unwrap(): def test_non_trainable(): - model = (jnp.ones(3), 1) model = non_trainable(model) @@ -54,9 +98,12 @@ def test_WeightNormalization(): test_cases = { - "NonTrainable": lambda key: NonTrainable(jr.normal(key, 10)), - "Parameterize-exp": lambda key: Parameterize(jnp.exp, jr.normal(key, 10)), + "NonTrainable": lambda key: NonTrainable(jr.normal(key, (10,))), + "Parameterize-exp": lambda key: Parameterize(jnp.exp, jr.normal(key, (10,))), "WeightNormalization": lambda key: WeightNormalization(jr.normal(key, (10, 2))), + "RealToIncreasingOnInterval": lambda key: RealToIncreasingOnInterval( + jnp.zeros(10), (-7, 5), min_width=0.2, include_ends="upper" + ), } From 5e8ca99a7c7157424a61bce9f5b0509de2bae291 Mon Sep 17 00:00:00 2001 From: Daniel Ward Date: Fri, 23 Jan 2026 08:28:34 +0000 Subject: [PATCH 2/3] Tidy up and docs --- paramax/wrappers.py | 37 +++++++++++++++++++++++++------------ pyproject.toml | 2 +- tests/test_wrappers.py | 4 ++-- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/paramax/wrappers.py b/paramax/wrappers.py index 50c6965..bdd65f8 100644 --- a/paramax/wrappers.py +++ b/paramax/wrappers.py @@ -161,21 +161,28 @@ def unwrap(self) -> T: class RealToIncreasingOnInterval(AbstractUnwrappable[Array]): - """Unconstrained vector to increasing on a fixed interval. + """Map an unconstrained vector to increasing points on a fixed interval. - Unconstrained vector is passed into softmax to obtain widths, - which are cumulatively summed and scaled to fit the interval - (or the remainder not used by min_width). + The input vector is transformed via a softmax into positive widths, to fill + the interval after adding minimum width. The cumulative sum of the widths + produces the incresing points. - Note an array of size d parameterizes widths, so maps to an array - of size d+1 if both endpoints are included, d if one is included, - and d-1 if neither are included. + If an array of size d is used, the result has size d+1 if both + endpoints are included, d if one is included, and d-1 if neither + are included. + + Args: + arr: Unconstrained vector parameterizing the widths. + interval: (lower, upper) bounds of the interval. + min_width: Minimum spacing between consecutive points. + include_endpoints: Which endpoints to include: "both", "neither", + "lower", or "upper". """ arr: Array interval: tuple[float | int, float | int] min_width: float - include_ends: Literal["both", "neither", "lower", "upper"] + include_endpoints: Literal["both", "neither", "lower", "upper"] def __init__( self, @@ -183,7 +190,7 @@ def __init__( interval: tuple[float | int, float | int], *, min_width: float, - include_ends: Literal["both", "neither", "lower", "upper"], + include_endpoints: Literal["both", "neither", "lower", "upper"], ): scale = interval[1] - interval[0] n_widths = arr.shape[-1] @@ -199,10 +206,16 @@ def __init__( "min_width*n_widths is greater than the interval width, so cannot be " "satisfied." ) + + if include_endpoints not in ["both", "neither", "lower", "upper"]: + raise ValueError( + "include_endpoints must be one of 'both', 'neither', 'lower', or 'upper'", + ) + self.arr = arr self.interval = interval self.min_width = min_width - self.include_ends = include_ends + self.include_endpoints = include_endpoints def unwrap(self) -> Array: scale = self.interval[1] - self.interval[0] @@ -216,9 +229,9 @@ def _unwrap(arr): lower, upper = jnp.array([self.interval[0]]), jnp.array([self.interval[1]]) return jnp.concatenate( [ - *([lower] if self.include_ends in ("lower", "both") else []), + *([lower] if self.include_endpoints in ("lower", "both") else []), jnp.cumsum(widths, axis=-1)[:-1] + lower, - *([upper] if self.include_ends in ("upper", "both") else []), + *([upper] if self.include_endpoints in ("upper", "both") else []), ] ) diff --git a/pyproject.toml b/pyproject.toml index 28a702f..1ff37ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ license = { file = "LICENSE" } name = "paramax" readme = "README.md" requires-python = ">=3.10" -version = "0.0.3" +version = "0.0.4" [project.urls] repository = "https://github.com/danielward27/paramax" diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 0252d51..035db68 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -54,7 +54,7 @@ def test_RealToIncreasingOnInterval(case): case["arr"], (-1, 2), min_width=0.1, - include_ends=case["include_ends"], + include_endpoints=case["include_ends"], ) result = unwrap(real_to_inc) assert pytest.approx(case["expected"]) == result @@ -102,7 +102,7 @@ def test_WeightNormalization(): "Parameterize-exp": lambda key: Parameterize(jnp.exp, jr.normal(key, (10,))), "WeightNormalization": lambda key: WeightNormalization(jr.normal(key, (10, 2))), "RealToIncreasingOnInterval": lambda key: RealToIncreasingOnInterval( - jnp.zeros(10), (-7, 5), min_width=0.2, include_ends="upper" + jnp.zeros(10), (-7, 5), min_width=0.2, include_endpoints="upper" ), } From d9017faf15351d5b89cad4bcc8b31dcaa1f56402 Mon Sep 17 00:00:00 2001 From: Daniel Ward Date: Fri, 23 Jan 2026 08:31:12 +0000 Subject: [PATCH 3/3] Remove comment --- tests/test_wrappers.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 035db68..320ada4 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -16,9 +16,6 @@ unwrap, ) -# %% - - to_inverval_test_cases = [ { "arr": jnp.zeros(2),