
utils.arraylike_to_array fails for jax==0.8.2 due to isinstance(arr, ArrayLike) failure #230

@ThibeauWouters


The release of jax v0.8.2 made the following change:

> jax's Tracer no longer inherits from jax.Array at runtime. However,
> jax.Array now uses a custom metaclass such [that] isinstance(x, Array) is true
> if an object x represents a traced Array. Only some Tracers represent
> Arrays, so it is not correct for Tracer to inherit from Array.
>
> For the moment, during Python type checking, we continue to declare Tracer
> as a subclass of Array, however we expect to remove this in a future
> release.

This change appears to alter the behavior of `isinstance(arr, ArrayLike)`: it now returns `False` when `arr` is a `BatchTracer` (observed on Python 3.11). As a result, `utils.arraylike_to_array` raises a `TypeError` when called inside transformations such as `jax.vmap`.

Below is a minimal working example (MWE) comparing the current implementation with a slightly different one that fixes the function (though it may alter its behavior in unexpected ways):

```python
import jax
import jax.numpy as jnp
from jaxtyping import ArrayLike, Array
import flowjax

jax.print_environment_info()

print(f"flowjax: {flowjax.__version__}\n")


def arraylike_to_array_OLD(
    arr: ArrayLike | None, err_name: str = "input", **kwargs
) -> Array:
    # Current flowjax behavior: explicit isinstance check against ArrayLike.
    if not isinstance(arr, ArrayLike):
        raise TypeError(
            f"Expected {err_name} to be arraylike; got {type(arr).__name__}.",
        )
    return jnp.asarray(arr, **kwargs)


def arraylike_to_array_NEW(
    arr: ArrayLike | None, err_name: str = "input", **kwargs
) -> Array:
    # Alternative: let jnp.asarray decide what is convertible, and re-raise
    # conversion failures with the same error message as before.
    try:
        return jnp.asarray(arr, **kwargs)
    except (TypeError, ValueError) as e:
        raise TypeError(
            f"Expected {err_name} to be arraylike; got {type(arr).__name__}.",
        ) from e


print("Testing old implementation:")
try:
    jax.vmap(arraylike_to_array_OLD)(jnp.array([[1.0, 2.0]]))
except Exception as e:
    print("Error occurred")
    print(e)

print("Testing new implementation:")
try:
    jax.vmap(arraylike_to_array_NEW)(jnp.array([[1.0, 2.0]]))
    print("Executed successfully.")
except Exception as e:
    print("Error occurred")
    print(e)
```

Output:

```
jax:    0.8.2
jaxlib: 0.8.2
numpy:  2.4.0
python: 3.11.14 (main, Nov 19 2025, 23:12:58) [Clang 21.1.4 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='UU-MP63GHX43K', release='24.1.0', version='Darwin Kernel Version 24.1.0: Thu Oct 10 21:02:45 PDT 2024; root:xnu-11215.41.3~2/RELEASE_ARM64_T8112', machine='arm64')
flowjax: 17.2.1

Testing old implementation:
Error occurred
Expected input to be arraylike; got BatchTracer.
Testing new implementation:
Executed successfully.
```
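The NEW implementation accepts anything `jnp.asarray` can convert, including a plain Python list, which the original `ArrayLike` check rejects. A sketch of a stricter alternative, assuming the members of `jax.typing.ArrayLike` are `jax.Array`, NumPy arrays and scalars, and Python scalars (verify this against your jax version): check against an explicit tuple of types, for which `isinstance` consults each member, including `jax.Array`'s metaclass. `arraylike_to_array` and `_ARRAYLIKE_TYPES` here are hypothetical, not flowjax's actual API.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import Array

# Assumption: mirrors the members of jax.typing.ArrayLike (jax.Array,
# NumPy arrays and scalars, Python scalars). Adjust if your version differs.
_ARRAYLIKE_TYPES = (Array, np.ndarray, np.bool_, np.number, bool, int, float, complex)


def arraylike_to_array(arr, err_name: str = "input", **kwargs) -> Array:
    """Hypothetical stricter variant of flowjax's utils.arraylike_to_array."""
    # isinstance with a tuple calls isinstance on each member, so the
    # metaclass check on jax.Array can still recognise tracers under vmap.
    if not isinstance(arr, _ARRAYLIKE_TYPES):
        raise TypeError(
            f"Expected {err_name} to be arraylike; got {type(arr).__name__}."
        )
    return jnp.asarray(arr, **kwargs)


# Inside vmap the batched argument arrives as a tracer.
out = jax.vmap(arraylike_to_array)(jnp.array([[1.0, 2.0]]))
print(out.shape)

# A plain list is still rejected, although jnp.asarray alone would accept it.
try:
    arraylike_to_array([1.0, 2.0])
except TypeError as e:
    print(e)
```

Under `vmap` this should return an array of shape `(1, 2)`, and the list call should raise `TypeError` with the original message. The trade-off is that the tuple has to be kept in sync with `ArrayLike` by hand.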
