The release of jax v0.8.2 made the following change:
jax's Tracer no longer inherits from jax.Array at runtime. However,
jax.Array now uses a custom metaclass such that isinstance(x, Array) is true
if an object x represents a traced Array. Only some Tracers represent
Arrays, so it is not correct for Tracer to inherit from Array.
For the moment, during Python type checking, we continue to declare Tracer
as a subclass of Array, however we expect to remove this in a future
release.
However, this change seems to alter the behavior of isinstance(arr, ArrayLike): it now returns False when arr is a BatchTracer (observed on Python 3.11). As a result, utils.arraylike_to_array fails inside transformations such as vmap.
Here is an MWE, along with a slightly different implementation that fixes the function (though it might alter its behavior in unexpected ways):
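One plausible mechanism, stated as an assumption rather than verified against the jax source: jax.typing.ArrayLike is a typing.Union, and on Python 3.10-3.13 `isinstance(obj, Union[...])` is implemented as `issubclass(type(obj), arg)` for each member, so a metaclass that only customises `__instancecheck__` is never consulted. By contrast, `isinstance(obj, (A, B, ...))` with a plain tuple calls `isinstance` on each member and does honour the hook. The sketch below uses hypothetical toy classes (`Array`, `Tracer`, `_ArrayMeta`), not jax's real ones, to show the difference with the standard library only.

```python
class _ArrayMeta(type):
    # Toy metaclass (hypothetical): customises isinstance() but not issubclass().
    def __instancecheck__(cls, obj):
        return getattr(obj, "represents_array", False) or super().__instancecheck__(obj)


class Array(metaclass=_ArrayMeta):
    """Toy stand-in for jax.Array (hypothetical)."""


class Tracer:
    """Toy stand-in for a tracer; deliberately NOT a subclass of Array."""

    def __init__(self, represents_array: bool):
        self.represents_array = represents_array


ARRAYLIKE_TYPES = (Array, int, float, complex, bool)


def union_style_check(obj) -> bool:
    # Mimics typing.Union's instance check on Python 3.10-3.13:
    # it tests issubclass(type(obj), arg), so the metaclass hook is skipped.
    return any(issubclass(type(obj), t) for t in ARRAYLIKE_TYPES)


def tuple_style_check(obj) -> bool:
    # isinstance() with a tuple calls isinstance() on each member,
    # so _ArrayMeta.__instancecheck__ is consulted.
    return isinstance(obj, ARRAYLIKE_TYPES)


traced = Tracer(represents_array=True)
print(isinstance(traced, Array))   # True: metaclass hook
print(issubclass(Tracer, Array))   # False: no inheritance
print(union_style_check(traced))   # False
print(tuple_style_check(traced))   # True
```

If this assumption holds, it would also explain why the failure depends on the Python version, since newer Pythons changed how typing.Union is implemented.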
```python
import jax
import jax.numpy as jnp
from jaxtyping import ArrayLike, Array

import flowjax

jax.print_environment_info()
print(f"flowjax: {flowjax.__version__}\n")


def arraylike_to_array_OLD(
    arr: ArrayLike | None, err_name: str = "input", **kwargs
) -> Array:
    # Current behavior: validate with isinstance before converting.
    if not isinstance(arr, ArrayLike):
        raise TypeError(
            f"Expected {err_name} to be arraylike; got {type(arr).__name__}.",
        )
    return jnp.asarray(arr, **kwargs)


def arraylike_to_array_NEW(
    arr: ArrayLike | None, err_name: str = "input", **kwargs
) -> Array:
    # Proposed behavior: attempt the conversion and translate any failure.
    try:
        return jnp.asarray(arr, **kwargs)
    except (TypeError, ValueError) as e:
        raise TypeError(
            f"Expected {err_name} to be arraylike; got {type(arr).__name__}.",
        ) from e


print("Testing old implementation:")
try:
    jax.vmap(arraylike_to_array_OLD)(jnp.array([[1.0, 2.0]]))
except Exception as e:
    print("Error occurred")
    print(e)

print("Testing new implementation:")
try:
    jax.vmap(arraylike_to_array_NEW)(jnp.array([[1.0, 2.0]]))
    print("Executed successfully.")
except Exception as e:
    print("Error occurred")
    print(e)
```

Output:

```
jax: 0.8.2
jaxlib: 0.8.2
numpy: 2.4.0
python: 3.11.14 (main, Nov 19 2025, 23:12:58) [Clang 21.1.4 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='UU-MP63GHX43K', release='24.1.0', version='Darwin Kernel Version 24.1.0: Thu Oct 10 21:02:45 PDT 2024; root:xnu-11215.41.3~2/RELEASE_ARM64_T8112', machine='arm64')
flowjax: 17.2.1
Testing old implementation:
Error occurred
Expected input to be arraylike; got BatchTracer.
Testing new implementation:
Executed successfully.
```
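A possible alternative that keeps an explicit type check (unlike arraylike_to_array_NEW, which accepts anything jnp.asarray accepts) is to test against a plain tuple of types instead of the ArrayLike union. This is a hedged sketch, not a flowjax API: `is_arraylike`, `_ARRAYLIKE_TYPES`, and `arraylike_to_array_TUPLE` are hypothetical names, the tuple is assumed to mirror the members of jax.typing.ArrayLike, and it relies on the changelog's statement that isinstance(x, jax.Array) is true for traced Arrays.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical: assumed to mirror the members of jax.typing.ArrayLike.
_ARRAYLIKE_TYPES = (jax.Array, np.ndarray, np.bool_, np.number, bool, int, float, complex)


def is_arraylike(x) -> bool:
    # isinstance() with a tuple consults each class's own __instancecheck__,
    # so jax.Array's metaclass gets a chance to recognise traced arrays.
    return isinstance(x, _ARRAYLIKE_TYPES)


def arraylike_to_array_TUPLE(arr, err_name: str = "input", **kwargs) -> jax.Array:
    # Same contract as arraylike_to_array_OLD, but with a tuple-based check.
    if not is_arraylike(arr):
        raise TypeError(
            f"Expected {err_name} to be arraylike; got {type(arr).__name__}.",
        )
    return jnp.asarray(arr, **kwargs)


out = jax.vmap(arraylike_to_array_TUPLE)(jnp.array([[1.0, 2.0]]))
print(out.shape)  # (1, 2)
```

The trade-off is that the member list must be kept in sync with jax.typing.ArrayLike by hand, whereas the try/except approach defers entirely to jnp.asarray.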