
Conditionally run health checks on jitted JAX arrays and dask arrays #225

Closed
@crusaderky

Description


I've found functions in scipy, e.g.

https://github.com/scipy/scipy/blob/4758525cae48f9cfb6971be6702fbb412e783aa5/scipy/cluster/vq.py#L332-L334

that crash within their first few lines when they are called with JAX arrays inside @jax.jit:

>>> import jax
>>> import jax.numpy as xp
>>> from scipy.cluster.vq import kmeans
>>> a = xp.asarray([[1.,2.],[3.,4.]])
>>> jax.jit(kmeans)(a, 2)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function kmeans at /home/crusaderky/github/scipy/build-install/lib/python3.12/site-packages/scipy/cluster/vq.py:332 for jit. This concrete value was not available in Python because it depends on the value of the argument obs.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

The issue is that there is a health check enabled by default, check_finite=True, which triggers this code:

https://github.com/scipy/scipy/blob/4758525cae48f9cfb6971be6702fbb412e783aa5/scipy/_lib/_array_api.py#L104-L108

A jitted JAX array raises on bool(). A dask array is instead silently computed when you do so, which is arguably even worse.
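Here is a toy model of the two failure modes. It is written in plain Python so it runs without JAX or dask. `TracedScalar`, `LazyScalar`, and `check_finite_like` are hypothetical stand-ins, not real JAX, dask, or scipy APIs. They only mimic what `bool()` does in each library on the result of an "are all elements finite?" reduction.

```python
import math


class TracedScalar:
    """Hypothetical stand-in for xp.all(xp.isfinite(x)) traced under @jax.jit."""

    def __bool__(self) -> bool:
        # JAX raises TracerBoolConversionError here; a plain TypeError stands in for it
        raise TypeError("Attempted boolean conversion of traced array")


class LazyScalar:
    """Hypothetical stand-in for xp.all(xp.isfinite(x)) on a dask array."""

    def __init__(self, values: list[float]) -> None:
        self.values = values
        self.compute_count = 0  # how many times bool() forced a computation

    def __bool__(self) -> bool:
        # No error: the data are quietly computed on every bool() call
        self.compute_count += 1
        return all(math.isfinite(v) for v in self.values)


def check_finite_like(all_finite: object) -> None:
    """Hypothetical health check that depends on bool() of a reduction result."""
    if not all_finite:
        raise ValueError("array must not contain infs or NaNs")


# Traced case: the check crashes with an error unrelated to the user's own code
try:
    check_finite_like(TracedScalar())
except TypeError as exc:
    print("traced:", exc)

# Lazy case: the check passes, but only after computing the data eagerly
lazy = LazyScalar([1.0, 2.0, 3.0, 4.0])
check_finite_like(lazy)
print("lazy computations:", lazy.compute_count)  # → 1
```

The lazy case prints no error at all. That silence is what makes it the more insidious of the two: the check appears to work while it forces a computation the caller never asked for.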

There are two issues here:

  1. the default behaviour of the function is to inspect the contents of the array, and
  2. the error message is incomprehensible to an end user, as it is triggered by code deep inside the scipy implementation.

My proposal:

  1. in array-api-compat, add two functions:
def is_jax_jitted_array(x):
    return isinstance(x, DynamicJaxprTracer)

def is_material_array(x):
    """Return True if x has contents in memory at the moment of calling this function,
    which are cheap to retrieve as long as they're small in size.
    Return False if x is a future or it would be otherwise impossible or expensive to
    read its contents, regardless of their size.
    """
    return not is_dask_array(x) and not is_jax_jitted_array(x)
  2. in scipy, change kmeans(..., check_finite=True) to kmeans(..., check_finite=None), where None means "check if possible", and replace
if check_finite:
    _check_finite(x, xp)

with

def _check_material_array(x: Array, check: bool | None, check_name: str) -> bool:
    if check is None:
        return is_material_array(x)
    if check and not is_material_array(x):
        raise TypeError(f"Can't check non-material array {type(x)}. Please set {check_name} to None or False.")
    return check

...

if _check_material_array(x, check_finite, "check_finite"):
    _check_finite(x, xp)

However, @jakevdp mentioned elsewhere that DynamicJaxprTracer is not part of JAX's public API, and that there is no public way to test whether code is being jitted. Without such a test, I'm not sure I can see a way forward.

Also CC @rgommers @lucascolley
