.github/workflows/tests.yml (2 changes: 1 addition & 1 deletion)
@@ -58,7 +58,7 @@ jobs:
           python3 -m pylint $(find optax -name '*_test.py' | xargs) -d W0212,E1102 || pylint-exit $PYLINT_ARGS $?
       - name: Lint check quote consistency
         run: |
-          python3 -m pylint $(find optax -name '*.py' | grep -v 'test.py' | xargs) \
+          python3 -m pylint $(find optax -name '*.py' | xargs) \
             --disable=R,C,W,E --enable=inconsistent-quotes --check-quote-consistency=y
   build-and-pytype:
     needs: [pre-commit, flake8, pylint, ruff-lint]  # do not run tests if linting fails
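Note: removing the `grep -v 'test.py'` filter extends pylint's quote-consistency check to the `*_test.py` files as well. As a rough, hypothetical illustration (not code from this PR), the pattern that `--enable=inconsistent-quotes --check-quote-consistency=y` flags is a module mixing quote styles:

# hypothetical_test.py: mixing quote characters in one module now fails the
# quote-consistency lint for test files too.
GREETING = 'hello'    # single quotes (the majority style)
FAREWELL = "goodbye"  # double quotes: reported as inconsistent-quotes (W1405)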
optax/_src/alias_test.py (31 changes: 24 additions & 7 deletions)
@@ -97,6 +97,22 @@
 )


+def _get_opt(self: absltest.TestCase, opt_name: str):
+  if opt_name == 'optimistic_adam':
+    opt_ = getattr(alias, opt_name)
+
+    @functools.wraps(opt_)
+    def opt(*args, **kwargs):
+      with self.assertWarnsRegex(
+          DeprecationWarning, 'use `optimistic_adam_v2` instead'
+      ):
+        return opt_(*args, **kwargs)
+
+    return opt
+
+  return getattr(alias, opt_name)
+
+
 def _setup_parabola(dtype):
   """Quadratic function as an optimization target."""
   initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype)
@@ -162,7 +178,7 @@ def test_optimization(self, opt_name, opt_kwargs, target, dtype):
           ' Rosenbrockfunction'
       )

-    opt = getattr(alias, opt_name)(**opt_kwargs)
+    opt = _get_opt(self, opt_name)(**opt_kwargs)
     initial_params, final_params, objective = target(dtype)

     @jax.jit
@@ -217,7 +233,7 @@ def step(params, state):

   @parameterized.product(_OPTIMIZERS_UNDER_TEST)
   def test_optimizers_accept_extra_args(self, opt_name, opt_kwargs):
-    opt = getattr(alias, opt_name)(**opt_kwargs)
+    opt = _get_opt(self, opt_name)(**opt_kwargs)
     # intentionally ommit: opt = base.with_extra_args_support(opt)
     initial_params, _, objective = _setup_rosenbrock(jnp.float32)

@@ -244,7 +260,7 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
   ):
     """Checks that optimizers can be wrapped in inject_hyperparams."""
     # See also https://github.com/google-deepmind/optax/issues/412.
-    opt_factory = getattr(alias, opt_name)
+    opt_factory = _get_opt(self, opt_name)
     opt = opt_factory(**opt_kwargs)
     if opt_name == 'adafactor':
       # Adafactor wrapped in inject_hyperparams currently needs a static
@@ -292,7 +308,7 @@ def test_explicit_dtype(self, params_dtype, state_dtype, opt_name):
       opt = alias.sgd(0.1, momentum=0.9, accumulator_dtype=state_dtype)
       attribute_name = 'trace'
     elif opt_name in ['adam', 'adamw']:
-      opt = getattr(alias, opt_name)(0.1, mu_dtype=state_dtype)
+      opt = _get_opt(self, opt_name)(0.1, mu_dtype=state_dtype)
       attribute_name = 'mu'
     else:
       raise ValueError(f'Unsupported optimizer: {opt_name}')
@@ -321,7 +337,7 @@ def test_preserve_dtype(self, opt_name, opt_kwargs, dtype):
     # The solution is to fix the dtype of the result to the desired dtype
     # (just as done in optax.tree.bias_correction).
     dtype = jnp.dtype(dtype)
-    opt_factory = getattr(alias, opt_name)
+    opt_factory = _get_opt(self, opt_name)
     opt = opt_factory(**opt_kwargs)
     fun = lambda x: jnp.sum(x**2)

@@ -352,7 +368,7 @@ def test_state_shape_dtype_shard_stability(self, opt_name, opt_kwargs, dtype):
     )

     with utils.x64_precision(dtype in (jnp.float64, jnp.complex128)):
-      opt = getattr(alias, opt_name)(**opt_kwargs)
+      opt = _get_opt(self, opt_name)(**opt_kwargs)
       initial_params, _, objective = _setup_parabola(dtype)

       @jax.jit
@@ -397,7 +413,8 @@ def test_gradient_accumulation(self, opt_name, opt_kwargs, dtype):
     """Test that the optimizers can safely be used with optax.MultiSteps."""
     # Checks for issues like https://github.com/google-deepmind/optax/issues/377
     dtype = jnp.dtype(dtype)
-    opt_factory = getattr(alias, opt_name)
+    opt_factory = _get_opt(self, opt_name)
+
     base_opt = opt_factory(**opt_kwargs)
     opt = _accumulation.MultiSteps(base_opt, every_k_schedule=4)

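For context, `_get_opt` assumes that `alias.optimistic_adam` now emits a `DeprecationWarning` pointing at `optimistic_adam_v2`, and simply asserts that warning around each call while forwarding the factory's return value. A minimal sketch of the deprecation pattern being tested (illustrative names and body, not the actual optax implementation):

import warnings

def optimistic_adam_v2(learning_rate=1e-3):
  # Stand-in for the replacement factory.
  ...

def optimistic_adam(learning_rate=1e-3):
  # Deprecated alias: warn, then defer to the replacement factory.
  warnings.warn(
      'optimistic_adam is deprecated, use `optimistic_adam_v2` instead',
      DeprecationWarning,
  )
  return optimistic_adam_v2(learning_rate)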
optax/experimental/microbatching.py (3 changes: 2 additions & 1 deletion)
@@ -68,7 +68,8 @@ class Accumulator:

 def _with_floating_check(fn: Function) -> Function:
   def wrapper(*args, **kwargs):
-    dtypes, _ = jax.tree.flatten(jax.tree.map(jnp.dtype, (args, kwargs)))
+    dtypes, _ = jax.tree.flatten(
+        jax.tree.map(lambda x: x.dtype, (args, kwargs)))
     if not all(jnp.issubdtype(dtype, jnp.floating) for dtype in dtypes):
       raise ValueError(
           'MEAN and RUNNING_MEAN Accumulators require floating-point values.'
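The wrapper now reads each leaf's `.dtype` attribute instead of calling `jnp.dtype` on the leaf itself. A self-contained sketch of the resulting check, with illustrative inputs (not from the optax test suite):

import jax
import jax.numpy as jnp

args = (jnp.ones(3), jnp.zeros((2, 2), dtype=jnp.bfloat16))
kwargs = {'mask': jnp.array([1, 0, 1])}  # integer leaf: should fail the check

# Collect the dtype of every array leaf in the (args, kwargs) tree.
dtypes, _ = jax.tree.flatten(jax.tree.map(lambda x: x.dtype, (args, kwargs)))
print(all(jnp.issubdtype(d, jnp.floating) for d in dtypes))  # False: int32 'mask'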
optax/losses/_classification_test.py (4 changes: 3 additions & 1 deletion)
@@ -631,7 +631,9 @@ def test_axis(self, shape, axis):
     np.testing.assert_allclose(x, y, atol=1e-4)

   def test_deprecated_alias(self):
-    x = _classification.convex_kl_divergence(self.log_ps[0], self.qs[0])
+    with self.assertWarnsRegex(DeprecationWarning,
+                               'use generalized_kl_divergence'):
+      x = _classification.convex_kl_divergence(self.log_ps[0], self.qs[0])
     y = _classification.generalized_kl_divergence(self.log_ps[0], self.qs[0])
     np.testing.assert_allclose(x, y, atol=1e-4)

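The deprecated `convex_kl_divergence` call is now exercised inside `assertWarnsRegex`, which both checks that a matching `DeprecationWarning` is raised and lets the return value flow out for the subsequent comparison. A standalone sketch of that pattern (hypothetical function, not the optax loss):

import unittest
import warnings

def deprecated_loss(p, q):
  # Hypothetical stand-in for a deprecated alias.
  warnings.warn('use generalized_kl_divergence', DeprecationWarning)
  return p - q

class DeprecatedAliasTest(unittest.TestCase):

  def test_warns_and_returns(self):
    with self.assertWarnsRegex(DeprecationWarning,
                               'use generalized_kl_divergence'):
      out = deprecated_loss(1.0, 0.5)
    self.assertEqual(out, 0.5)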
optax/tree_utils/_casting.py (2 changes: 1 addition & 1 deletion)
@@ -139,7 +139,7 @@ def tree_dtype(
   if not leaves:
     # If the tree is empty, we return the default dtype as given by JAX on
     # empty lists.
-    return jnp.dtype(jnp.asarray(leaves))
+    return jnp.asarray(leaves).dtype
   if mixed_dtype_handler is None:
     dtype = jnp.asarray(leaves[0]).dtype
     _tree_assert_all_dtypes_equal(tree, dtype)
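The empty-tree branch now reads the `.dtype` attribute directly. This relies on `jnp.asarray` turning an empty list into an array of JAX's default dtype; a quick standalone check (output assumes the default config, i.e. `jax_enable_x64` disabled):

import jax.numpy as jnp

# An empty list becomes an array of the default dtype, so .dtype is the
# documented fallback for an empty tree.
print(jnp.asarray([]).dtype)  # float32 by default; float64 with x64 enabled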