I tried to run where with chunked cupy arrays:
In [1]: import xarray as xr
...: import cupy
...: import dask.array as da
...:
...: arr = xr.DataArray(cupy.arange(4), dims="x")
...: mask = xr.DataArray(cupy.array([False, True, True, False]), dims="x")
this works:
In [2]: arr.where(mask)
Out[2]:
<xarray.DataArray (x: 4)>
array([nan, 1., 2., nan])
Dimensions without coordinates: x
this fails:
In [4]: arr.chunk().where(mask).compute()
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[4], line 1
----> 1 arr.chunk().where(mask).compute()
File ~/repos/xarray/xarray/core/dataarray.py:1095, in DataArray.compute(self, **kwargs)
1076 """Manually trigger loading of this array's data from disk or a
1077 remote source into memory and return a new array. The original is
1078 left unaltered.
(...)
1092 dask.compute
1093 """
1094 new = self.copy(deep=False)
-> 1095 return new.load(**kwargs)
File ~/repos/xarray/xarray/core/dataarray.py:1069, in DataArray.load(self, **kwargs)
1051 def load(self: T_DataArray, **kwargs) -> T_DataArray:
1052 """Manually trigger loading of this array's data from disk or a
1053 remote source into memory and return this array.
1054
(...)
1067 dask.compute
1068 """
-> 1069 ds = self._to_temp_dataset().load(**kwargs)
1070 new = self._from_temp_dataset(ds)
1071 self._variable = new._variable
File ~/repos/xarray/xarray/core/dataset.py:752, in Dataset.load(self, **kwargs)
749 import dask.array as da
751 # evaluate all the dask arrays simultaneously
--> 752 evaluated_data = da.compute(*lazy_data.values(), **kwargs)
754 for k, data in zip(lazy_data, evaluated_data):
755 self.variables[k].data = data
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/base.py:600, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
597 keys.append(x.__dask_keys__())
598 postcomputes.append(x.__dask_postcompute__())
--> 600 results = schedule(dsk, keys, **kwargs)
601 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/threaded.py:89, in get(dsk, keys, cache, num_workers, pool, **kwargs)
86 elif isinstance(pool, multiprocessing.pool.Pool):
87 pool = MultiprocessingPoolExecutor(pool)
---> 89 results = get_async(
90 pool.submit,
91 pool._max_workers,
92 dsk,
93 keys,
94 cache=cache,
95 get_id=_thread_get_id,
96 pack_exception=pack_exception,
97 **kwargs,
98 )
100 # Cleanup pools associated to dead threads
101 with pools_lock:
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/local.py:511, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
509 _execute_task(task, data) # Re-execute locally
510 else:
--> 511 raise_exception(exc, tb)
512 res, worker_id = loads(res_info)
513 state["cache"][key] = res
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/local.py:319, in reraise(exc, tb)
317 if exc.__traceback__ is not tb:
318 raise exc.with_traceback(tb)
--> 319 raise exc
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/local.py:224, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
222 try:
223 task, data = loads(task_info)
--> 224 result = _execute_task(task, data)
225 id = get_id()
226 result = dumps((result, id))
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
115 func, args = arg[0], arg[1:]
116 # Note: Don't assign the subtask results to a variable. numpy detects
117 # temporaries by their reference count and can execute certain
118 # operations in-place.
--> 119 return func(*(_execute_task(a, cache) for a in args))
120 elif not ishashable(arg):
121 return arg
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/optimization.py:990, in SubgraphCallable.__call__(self, *args)
988 if not len(args) == len(self.inkeys):
989 raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 990 return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/core.py:149, in get(dsk, out, cache)
147 for key in toposort(dsk):
148 task = dsk[key]
--> 149 result = _execute_task(task, cache)
150 cache[key] = result
151 result = _execute_task(out, cache)
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
115 func, args = arg[0], arg[1:]
116 # Note: Don't assign the subtask results to a variable. numpy detects
117 # temporaries by their reference count and can execute certain
118 # operations in-place.
--> 119 return func(*(_execute_task(a, cache) for a in args))
120 elif not ishashable(arg):
121 return arg
File <__array_function__ internals>:180, in where(*args, **kwargs)
File cupy/_core/core.pyx:1723, in cupy._core.core._ndarray_base.__array_function__()
File ~/.local/opt/mambaforge/envs/xarray/lib/python3.10/site-packages/cupy/_sorting/search.py:211, in where(condition, x, y)
209 if fusion._is_fusing():
210 return fusion._call_ufunc(_where_ufunc, condition, x, y)
--> 211 return _where_ufunc(condition.astype('?'), x, y)
File cupy/_core/_kernel.pyx:1287, in cupy._core._kernel.ufunc.__call__()
File cupy/_core/_kernel.pyx:160, in cupy._core._kernel._preprocess_args()
File cupy/_core/_kernel.pyx:146, in cupy._core._kernel._preprocess_arg()
TypeError: Unsupported type <class 'numpy.ndarray'>
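Judging by the end of the traceback, the scalar fill value (NaN) reaches cupy.where as a 0d numpy array, which cupy kernels reject; presumably something on the chunked code path converts the scalar with np.asarray because the dask wrapper hides the cupy chunks. A minimal sketch of that failure mode at the cupy level (my reconstruction, not part of the session above):

import numpy as np
import cupy

cond = cupy.array([False, True, True, False])
x = cupy.arange(4)

# a 0d numpy array as the third argument reproduces the error, since cupy
# kernels only accept cupy arrays and plain scalars:
cupy.where(cond, x, np.asarray(np.nan))  # TypeError: Unsupported type <class 'numpy.ndarray'>

# a plain scalar or a cupy array works:
cupy.where(cond, x, cupy.asarray(np.nan))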
this works again:
In [7]: arr.chunk().where(mask.chunk(), cupy.array(cupy.nan)).compute()
Out[7]:
<xarray.DataArray (x: 4)>
array([nan, 1., 2., nan])
Dimensions without coordinates: x
And other methods like fillna show similar behavior.
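For instance, I'd expect the analogous fillna calls to behave the same way (untested sketch, I only ran the where variants above):

# expected: same TypeError, since the scalar fill value presumably gets
# converted to a 0d numpy array on the chunked code path
arr.chunk().fillna(0).compute()

# expected: works, because the fill value is already a cupy array
arr.chunk().fillna(cupy.array(0)).compute()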
I think the reason is that the check at xarray/xarray/core/duck_array_ops.py, line 195 in d4db166, does not detect cupy beneath other layers of duckarrays (most commonly dask, pint, or both). In this specific case we could extend the condition to also match chunked cupy arrays (like arr.cupy.is_cupy does, but using is_duck_dask_array), but this will still break for other duckarray layers or if dask is not involved, and we're also in the process of moving away from special-casing dask. So short of asking cupy to treat 0d arrays like scalars I'm not sure how to fix this.
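For reference, a rough sketch of what extending that condition might look like; the helper name is made up and the import path is from memory of xarray's internals around d4db166, so treat it as pseudocode rather than a drop-in patch:

from xarray.core.pycompat import is_duck_dask_array

def _contains_cupy(x, cupy_ndarray_type):
    # bare cupy array
    if isinstance(x, cupy_ndarray_type):
        return True
    # chunked cupy: a dask array whose _meta (the zero-size prototype of
    # the wrapped chunks) is a cupy array, mirroring what arr.cupy.is_cupy
    # checks
    if is_duck_dask_array(x):
        return isinstance(getattr(x, "_meta", None), cupy_ndarray_type)
    return False

As said above, though, this would still miss other wrapper combinations (e.g. pint around cupy), so it's more of a band-aid than a fix.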