diff --git a/.gitignore b/.gitignore index ffb652e..5fd84d6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ *.py[cod] .pylintrc - +uv.lock # C extensions *.so diff --git a/paramax/__init__.py b/paramax/__init__.py index d3e68c9..4c58f2a 100644 --- a/paramax/__init__.py +++ b/paramax/__init__.py @@ -6,6 +6,7 @@ AbstractUnwrappable, NonTrainable, Parameterize, + RealToIncreasingOnInterval, WeightNormalization, contains_unwrappables, non_trainable, diff --git a/paramax/wrappers.py b/paramax/wrappers.py index 6b6d7d1..bdd65f8 100644 --- a/paramax/wrappers.py +++ b/paramax/wrappers.py @@ -6,13 +6,14 @@ from abc import abstractmethod from collections.abc import Callable -from typing import Any, Generic, TypeVar +from functools import partial +from typing import Any, Generic, Literal, TypeVar import equinox as eqx import jax import jax.numpy as jnp from jax import lax -from jax.nn import softplus +from jax.nn import softmax, softplus from jax.tree_util import tree_leaves from jaxtyping import Array, PyTree @@ -159,6 +160,84 @@ def unwrap(self) -> T: return eqx.combine(lax.stop_gradient(differentiable), static) +class RealToIncreasingOnInterval(AbstractUnwrappable[Array]): + """Map an unconstrained vector to increasing points on a fixed interval. + + The input vector is transformed via a softmax into positive widths, to fill + the interval after adding minimum width. The cumulative sum of the widths + produces the incresing points. + + If an array of size d is used, the result has size d+1 if both + endpoints are included, d if one is included, and d-1 if neither + are included. + + Args: + arr: Unconstrained vector parameterizing the widths. + interval: (lower, upper) bounds of the interval. + min_width: Minimum spacing between consecutive points. + include_endpoints: Which endpoints to include: "both", "neither", + "lower", or "upper". + """ + + arr: Array + interval: tuple[float | int, float | int] + min_width: float + include_endpoints: Literal["both", "neither", "lower", "upper"] + + def __init__( + self, + arr: Array, + interval: tuple[float | int, float | int], + *, + min_width: float, + include_endpoints: Literal["both", "neither", "lower", "upper"], + ): + scale = interval[1] - interval[0] + n_widths = arr.shape[-1] + + if min_width <= 0: + raise ValueError("min_width must be greater than or equal to 0.") + + if interval[1] <= interval[0]: + raise ValueError("interval[1] must be greater than interval[0]") + + if min_width * n_widths > scale: + raise ValueError( + "min_width*n_widths is greater than the interval width, so cannot be " + "satisfied." + ) + + if include_endpoints not in ["both", "neither", "lower", "upper"]: + raise ValueError( + "include_endpoints must be one of 'both', 'neither', 'lower', or 'upper'", + ) + + self.arr = arr + self.interval = interval + self.min_width = min_width + self.include_endpoints = include_endpoints + + def unwrap(self) -> Array: + scale = self.interval[1] - self.interval[0] + + @partial(jnp.vectorize, signature="(a)->(b)") + def _unwrap(arr): + n_widths = self.arr.shape[-1] + widths = softmax(arr) + free = scale - self.min_width * n_widths + widths = self.min_width + free * widths + lower, upper = jnp.array([self.interval[0]]), jnp.array([self.interval[1]]) + return jnp.concatenate( + [ + *([lower] if self.include_endpoints in ("lower", "both") else []), + jnp.cumsum(widths, axis=-1)[:-1] + lower, + *([upper] if self.include_endpoints in ("upper", "both") else []), + ] + ) + + return _unwrap(self.arr) + + class WeightNormalization(AbstractUnwrappable[Array]): """Applies weight normalization (https://arxiv.org/abs/1602.07868). diff --git a/pyproject.toml b/pyproject.toml index 28a702f..1ff37ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ license = { file = "LICENSE" } name = "paramax" readme = "README.md" requires-python = ">=3.10" -version = "0.0.3" +version = "0.0.4" [project.urls] repository = "https://github.com/danielward27/paramax" diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 5e88eaa..320ada4 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -1,3 +1,4 @@ +# $$ from math import prod import equinox as eqx @@ -9,11 +10,52 @@ from paramax.wrappers import ( NonTrainable, Parameterize, + RealToIncreasingOnInterval, WeightNormalization, non_trainable, unwrap, ) +to_inverval_test_cases = [ + { + "arr": jnp.zeros(2), + "include_ends": "both", + "expected": jnp.array([-1, 0.5, 2]), + }, + { + "arr": jnp.zeros(2), + "include_ends": "lower", + "expected": jnp.array([-1, 0.5]), + }, + { + "arr": jnp.zeros(2), + "include_ends": "upper", + "expected": jnp.array([0.5, 2]), + }, + { + "arr": jnp.zeros(2), + "include_ends": "neither", + "expected": jnp.array([0.5]), + }, + { + "arr": jnp.array([-100, 0]), + "include_ends": "both", + "expected": jnp.array([-1, -0.9, 2]), + }, +] + + +@pytest.mark.parametrize("case", to_inverval_test_cases) +def test_RealToIncreasingOnInterval(case): + real_to_inc = RealToIncreasingOnInterval( + case["arr"], + (-1, 2), + min_width=0.1, + include_endpoints=case["include_ends"], + ) + result = unwrap(real_to_inc) + assert pytest.approx(case["expected"]) == result + def test_Parameterize(): diag = Parameterize(jnp.diag, jnp.ones(3)) @@ -29,7 +71,6 @@ def test_nested_unwrap(): def test_non_trainable(): - model = (jnp.ones(3), 1) model = non_trainable(model) @@ -54,9 +95,12 @@ def test_WeightNormalization(): test_cases = { - "NonTrainable": lambda key: NonTrainable(jr.normal(key, 10)), - "Parameterize-exp": lambda key: Parameterize(jnp.exp, jr.normal(key, 10)), + "NonTrainable": lambda key: NonTrainable(jr.normal(key, (10,))), + "Parameterize-exp": lambda key: Parameterize(jnp.exp, jr.normal(key, (10,))), "WeightNormalization": lambda key: WeightNormalization(jr.normal(key, (10, 2))), + "RealToIncreasingOnInterval": lambda key: RealToIncreasingOnInterval( + jnp.zeros(10), (-7, 5), min_width=0.2, include_endpoints="upper" + ), }