Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
*.py[cod]
.pylintrc

uv.lock
# C extensions
*.so

Expand Down
1 change: 1 addition & 0 deletions paramax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
AbstractUnwrappable,
NonTrainable,
Parameterize,
RealToIncreasingOnInterval,
WeightNormalization,
contains_unwrappables,
non_trainable,
Expand Down
83 changes: 81 additions & 2 deletions paramax/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@

from abc import abstractmethod
from collections.abc import Callable
from typing import Any, Generic, TypeVar
from functools import partial
from typing import Any, Generic, Literal, TypeVar

import equinox as eqx
import jax
import jax.numpy as jnp
from jax import lax
from jax.nn import softplus
from jax.nn import softmax, softplus
from jax.tree_util import tree_leaves
from jaxtyping import Array, PyTree

Expand Down Expand Up @@ -159,6 +160,84 @@ def unwrap(self) -> T:
return eqx.combine(lax.stop_gradient(differentiable), static)


class RealToIncreasingOnInterval(AbstractUnwrappable[Array]):
"""Map an unconstrained vector to increasing points on a fixed interval.

The input vector is transformed via a softmax into positive widths, to fill
the interval after adding minimum width. The cumulative sum of the widths
produces the incresing points.

If an array of size d is used, the result has size d+1 if both
endpoints are included, d if one is included, and d-1 if neither
are included.

Args:
arr: Unconstrained vector parameterizing the widths.
interval: (lower, upper) bounds of the interval.
min_width: Minimum spacing between consecutive points.
include_endpoints: Which endpoints to include: "both", "neither",
"lower", or "upper".
"""

arr: Array
interval: tuple[float | int, float | int]
min_width: float
include_endpoints: Literal["both", "neither", "lower", "upper"]

def __init__(
self,
arr: Array,
interval: tuple[float | int, float | int],
*,
min_width: float,
include_endpoints: Literal["both", "neither", "lower", "upper"],
):
scale = interval[1] - interval[0]
n_widths = arr.shape[-1]

if min_width <= 0:
raise ValueError("min_width must be greater than or equal to 0.")

if interval[1] <= interval[0]:
raise ValueError("interval[1] must be greater than interval[0]")

if min_width * n_widths > scale:
raise ValueError(
"min_width*n_widths is greater than the interval width, so cannot be "
"satisfied."
)

if include_endpoints not in ["both", "neither", "lower", "upper"]:
raise ValueError(
"include_endpoints must be one of 'both', 'neither', 'lower', or 'upper'",
)

self.arr = arr
self.interval = interval
self.min_width = min_width
self.include_endpoints = include_endpoints

def unwrap(self) -> Array:
scale = self.interval[1] - self.interval[0]

@partial(jnp.vectorize, signature="(a)->(b)")
def _unwrap(arr):
n_widths = self.arr.shape[-1]
widths = softmax(arr)
free = scale - self.min_width * n_widths
widths = self.min_width + free * widths
lower, upper = jnp.array([self.interval[0]]), jnp.array([self.interval[1]])
return jnp.concatenate(
[
*([lower] if self.include_endpoints in ("lower", "both") else []),
jnp.cumsum(widths, axis=-1)[:-1] + lower,
*([upper] if self.include_endpoints in ("upper", "both") else []),
]
)

return _unwrap(self.arr)


class WeightNormalization(AbstractUnwrappable[Array]):
"""Applies weight normalization (https://arxiv.org/abs/1602.07868).

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ license = { file = "LICENSE" }
name = "paramax"
readme = "README.md"
requires-python = ">=3.10"
version = "0.0.3"
version = "0.0.4"

[project.urls]
repository = "https://github.com/danielward27/paramax"
Expand Down
50 changes: 47 additions & 3 deletions tests/test_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# $$
from math import prod

import equinox as eqx
Expand All @@ -9,11 +10,52 @@
from paramax.wrappers import (
NonTrainable,
Parameterize,
RealToIncreasingOnInterval,
WeightNormalization,
non_trainable,
unwrap,
)

to_inverval_test_cases = [
{
"arr": jnp.zeros(2),
"include_ends": "both",
"expected": jnp.array([-1, 0.5, 2]),
},
{
"arr": jnp.zeros(2),
"include_ends": "lower",
"expected": jnp.array([-1, 0.5]),
},
{
"arr": jnp.zeros(2),
"include_ends": "upper",
"expected": jnp.array([0.5, 2]),
},
{
"arr": jnp.zeros(2),
"include_ends": "neither",
"expected": jnp.array([0.5]),
},
{
"arr": jnp.array([-100, 0]),
"include_ends": "both",
"expected": jnp.array([-1, -0.9, 2]),
},
]


@pytest.mark.parametrize("case", to_inverval_test_cases)
def test_RealToIncreasingOnInterval(case):
real_to_inc = RealToIncreasingOnInterval(
case["arr"],
(-1, 2),
min_width=0.1,
include_endpoints=case["include_ends"],
)
result = unwrap(real_to_inc)
assert pytest.approx(case["expected"]) == result


def test_Parameterize():
diag = Parameterize(jnp.diag, jnp.ones(3))
Expand All @@ -29,7 +71,6 @@ def test_nested_unwrap():


def test_non_trainable():

model = (jnp.ones(3), 1)
model = non_trainable(model)

Expand All @@ -54,9 +95,12 @@ def test_WeightNormalization():


test_cases = {
"NonTrainable": lambda key: NonTrainable(jr.normal(key, 10)),
"Parameterize-exp": lambda key: Parameterize(jnp.exp, jr.normal(key, 10)),
"NonTrainable": lambda key: NonTrainable(jr.normal(key, (10,))),
"Parameterize-exp": lambda key: Parameterize(jnp.exp, jr.normal(key, (10,))),
"WeightNormalization": lambda key: WeightNormalization(jr.normal(key, (10, 2))),
"RealToIncreasingOnInterval": lambda key: RealToIncreasingOnInterval(
jnp.zeros(10), (-7, 5), min_width=0.2, include_endpoints="upper"
),
}


Expand Down