Skip to content

as_shared_dtype converts scalars to 0d numpy arrays if chunked cupy is involved #7721

Open
@keewis

Description

@keewis

I tried to run where with chunked cupy arrays:

In [1]: import xarray as xr
   ...: import cupy
   ...: import dask.array as da
   ...: 
   ...: arr = xr.DataArray(cupy.arange(4), dims="x")
   ...: mask = xr.DataArray(cupy.array([False, True, True, False]), dims="x")

this works:

In [2]: arr.where(mask)
Out[2]: 
<xarray.DataArray (x: 4)>
array([nan,  1.,  2., nan])
Dimensions without coordinates: x

this fails:

In [4]: arr.chunk().where(mask).compute()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 1
----> 1 arr.chunk().where(mask).compute()

File ~/repos/xarray/xarray/core/dataarray.py:1095, in DataArray.compute(self, **kwargs)
   1076 """Manually trigger loading of this array's data from disk or a
   1077 remote source into memory and return a new array. The original is
   1078 left unaltered.
   (...)
   1092 dask.compute
   1093 """
   1094 new = self.copy(deep=False)
-> 1095 return new.load(**kwargs)

File ~/repos/xarray/xarray/core/dataarray.py:1069, in DataArray.load(self, **kwargs)
   1051 def load(self: T_DataArray, **kwargs) -> T_DataArray:
   1052     """Manually trigger loading of this array's data from disk or a
   1053     remote source into memory and return this array.
   1054 
   (...)
   1067     dask.compute
   1068     """
-> 1069     ds = self._to_temp_dataset().load(**kwargs)
   1070     new = self._from_temp_dataset(ds)
   1071     self._variable = new._variable

File ~/repos/xarray/xarray/core/dataset.py:752, in Dataset.load(self, **kwargs)
    749 import dask.array as da
    751 # evaluate all the dask arrays simultaneously
--> 752 evaluated_data = da.compute(*lazy_data.values(), **kwargs)
    754 for k, data in zip(lazy_data, evaluated_data):
    755     self.variables[k].data = data

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/base.py:600, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    597     keys.append(x.__dask_keys__())
    598     postcomputes.append(x.__dask_postcompute__())
--> 600 results = schedule(dsk, keys, **kwargs)
    601 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/threaded.py:89, in get(dsk, keys, cache, num_workers, pool, **kwargs)
     86     elif isinstance(pool, multiprocessing.pool.Pool):
     87         pool = MultiprocessingPoolExecutor(pool)
---> 89 results = get_async(
     90     pool.submit,
     91     pool._max_workers,
     92     dsk,
     93     keys,
     94     cache=cache,
     95     get_id=_thread_get_id,
     96     pack_exception=pack_exception,
     97     **kwargs,
     98 )
    100 # Cleanup pools associated to dead threads
    101 with pools_lock:

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/local.py:511, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
    509         _execute_task(task, data)  # Re-execute locally
    510     else:
--> 511         raise_exception(exc, tb)
    512 res, worker_id = loads(res_info)
    513 state["cache"][key] = res

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/local.py:319, in reraise(exc, tb)
    317 if exc.__traceback__ is not tb:
    318     raise exc.with_traceback(tb)
--> 319 raise exc

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/local.py:224, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    222 try:
    223     task, data = loads(task_info)
--> 224     result = _execute_task(task, data)
    225     id = get_id()
    226     result = dumps((result, id))

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/optimization.py:990, in SubgraphCallable.__call__(self, *args)
    988 if not len(args) == len(self.inkeys):
    989     raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 990 return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/core.py:149, in get(dsk, out, cache)
    147 for key in toposort(dsk):
    148     task = dsk[key]
--> 149     result = _execute_task(task, cache)
    150     cache[key] = result
    151 result = _execute_task(out, cache)

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File <__array_function__ internals>:180, in where(*args, **kwargs)

File cupy/_core/core.pyx:1723, in cupy._core.core._ndarray_base.__array_function__()

File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/cupy/_sorting/search.py:211, in where(condition, x, y)
    209 if fusion._is_fusing():
    210     return fusion._call_ufunc(_where_ufunc, condition, x, y)
--> 211 return _where_ufunc(condition.astype('?'), x, y)

File cupy/_core/_kernel.pyx:1287, in cupy._core._kernel.ufunc.__call__()

File cupy/_core/_kernel.pyx:160, in cupy._core._kernel._preprocess_args()

File cupy/_core/_kernel.pyx:146, in cupy._core._kernel._preprocess_arg()

TypeError: Unsupported type <class 'numpy.ndarray'>

this works again:

In [7]: arr.chunk().where(mask.chunk(), cupy.array(cupy.nan)).compute()
Out[7]: 
<xarray.DataArray (x: 4)>
array([nan,  1.,  2., nan])
Dimensions without coordinates: x

And other methods like fillna show similar behavior.

I think the reason is that this:

if any(isinstance(x, array_type("cupy")) for x in scalars_or_arrays):
is not sufficient to detect cupy beneath other layers of duck arrays (most commonly dask, pint, or both). In this specific case we could extend the condition to also match chunked cupy arrays (as arr.cupy.is_cupy does, but using is_duck_dask_array). However, that would still break for other duck array layers or when dask is not involved, and we're also in the process of moving away from special-casing dask. So, short of asking cupy to treat 0d arrays like scalars, I'm not sure how to fix this.

cc @jacobtomlinson

Metadata

Metadata

Assignees

No one assigned

    Labels

    topic-arrays: related to flexible array support

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions