Skip to content

Commit 0405ff5

Browse files
committed
Simplify Dask
1 parent 2e97b6f commit 0405ff5

File tree

1 file changed

+35
-164
lines changed

1 file changed

+35
-164
lines changed

src/array_api_extra/_apply.py

Lines changed: 35 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
44
from __future__ import annotations
55

6-
from collections.abc import Callable, Hashable, Mapping, Sequence
6+
from collections.abc import Callable, Sequence
77
from functools import wraps
88
from types import ModuleType
99
from typing import TYPE_CHECKING, Any, cast
@@ -20,9 +20,8 @@
2020
from typing import TypeAlias
2121

2222
import numpy as np
23-
import numpy.typing as npt
2423

25-
NumPyObject: TypeAlias = npt.NDArray[DType] | np.generic # type: ignore[no-any-explicit]
24+
NumPyObject: TypeAlias = np.ndarray[Any, Any] | np.generic # type: ignore[no-any-explicit]
2625

2726

2827
def apply_numpy_func( # type: ignore[no-any-explicit]
@@ -31,11 +30,6 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
3130
shapes: Sequence[tuple[int, ...]] | None = None,
3231
dtypes: Sequence[DType] | None = None,
3332
xp: ModuleType | None = None,
34-
input_indices: Sequence[Sequence[Hashable]] | None = None,
35-
core_indices: Sequence[Hashable] | None = None,
36-
output_indices: Sequence[Sequence[Hashable]] | None = None,
37-
adjust_chunks: Sequence[dict[Hashable, Callable[[int], int]]] | None = None,
38-
new_axes: Sequence[dict[Hashable, int]] | None = None,
3933
**kwargs: Any,
4034
) -> tuple[Array, ...]:
4135
"""
@@ -66,33 +60,6 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
6660
Default: infer the result type(s) from the input arrays.
6761
xp : array_namespace, optional
6862
The standard-compatible namespace for `args`. Default: infer.
69-
input_indices : Sequence[Sequence[Hashable]], optional
70-
Dask specific.
71-
Axes labels for each input array, e.g. if there are two args with respectively
72-
ndim=3 and 1, `input_indices` could be ``['ijk', 'j']`` or ``[(0, 1, 2),
73-
(1,)]``.
74-
Default: disallow Dask.
75-
core_indices : Sequence[Hashable], optional
76-
**Dask specific.**
77-
Axes of the input arrays that cannot be broken into chunks.
78-
Default: disallow Dask.
79-
output_indices : Sequence[Sequence[Hashable]], optional
80-
**Dask specific.**
81-
Axes labels for each output array. If `func` returns a single (non-sequence)
82-
output, this must be a sequence containing a single sequence of labels, e.g.
83-
``['ijk']``.
84-
Default: disallow Dask.
85-
adjust_chunks : Sequence[Mapping[Hashable, Callable[[int], int]]], optional
86-
**Dask specific.**
87-
Sequence of dicts, one per output, mapping index to function to be applied to
88-
each chunk to determine the output size. The total must add up to the output
89-
shape.
90-
Default: on Dask, the size along each index cannot change.
91-
new_axes : Sequence[Mapping[Hashable, int]], optional
92-
**Dask specific.**
93-
New indexes and their dimension lengths, one per output.
94-
Default: on Dask, there can't be `output_indices` that don't appear in
95-
`input_indices`.
9663
**kwargs : Any, optional
9764
Additional keyword arguments to pass verbatim to `func`.
9865
Any array objects in them won't be converted to NumPy.
@@ -124,43 +91,22 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
12491
<https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
12592
12693
Dask
127-
This allows applying eager functions to the individual chunks of dask arrays.
128-
The dask graph won't be computed. As a special limitation, `func` must return
129-
exactly one output.
94+
This allows applying eager functions to dask arrays.
95+
The dask graph won't be computed.
13096
131-
In order to enable running on Dask you need to specify at least
132-
`input_indices`, `output_indices`, and `core_indices`, but you may also need
133-
`adjust_chunks` and `new_axes` depending on the function.
97+
`apply_numpy_func` doesn't know if `func` reduces along any axes and shape
98+
changes are non-trivial in chunked Dask arrays. For these reasons, all inputs
99+
will be rechunked into a single chunk.
134100
135-
Read `dask.array.blockwise`:
136-
- ``input_indices`` map to the even ``*args`` of `dask.array.blockwise`
137-
- ``output_indices[0]`` maps to the ``out_ind`` parameter
138-
- ``adjust_chunks[0]`` maps to the ``adjust_chunks`` parameter
139-
- ``new_axes[0]`` maps to the ``new_axes`` parameter
101+
.. warning::
140102
141-
``core_indices`` is a safety measure to prevent incorrect results on
142-
Dask along chunked axes. Consider this::
103+
The whole operation needs to fit in memory all at once on a single worker.
143104
144-
>>> apply_numpy_func(lambda x: x + x.sum(axis=0), x,
145-
... input_indices=['ij'], output_indices=['ij'])
146-
147-
The above example would produce incorrect results if x is a dask array with more
148-
than one chunk along axis 0, as each chunk will calculate its own local
149-
subtotal. To prevent this, we need to declare the first axis of ``args[0]`` as a
150-
*core axis*::
151-
152-
>>> apply_numpy_func(lambda x: x + x.sum(axis=0), x,
153-
... input_indices=['ij'], output_indices=['ij'],
154-
... core_indices='i')
155-
156-
This will cause `apply_numpy_func` to raise if the first axis of `x` is broken
157-
along multiple chunks, thus forcing the final user to rechunk ahead of time:
158-
159-
>>> x = x.chunk({0: -1})
160-
161-
This needs to always be a conscious decision on behalf of the final user, as the
162-
new chunks will be larger than the old and may cause memory issues, unless chunk
163-
size is reduced along a different, non-core axis.
105+
The outputs will also be returned as a single chunk and you should consider
106+
rechunking them into smaller chunks afterwards.
107+
If you want to distribute the calculation across multiple workers, you
108+
should use `dask.array.map_blocks`, `dask.array.blockwise`,
109+
`dask.array.map_overlap`, or a native Dask wrapper instead of this function.
164110
"""
165111
if xp is None:
166112
xp = array_namespace(*args)
@@ -177,68 +123,30 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
177123
raise ValueError(msg)
178124

179125
if is_dask_namespace(xp):
180-
# General validation
181-
if len(shapes) > 1:
182-
msg = "dask.array.map_blocks() does not support multiple outputs"
183-
raise NotImplementedError(msg)
184-
if input_indices is None or output_indices is None or core_indices is None:
185-
msg = (
186-
"Dask is disallowed unless one declares input_indices, "
187-
"output_indices, and core_indices"
188-
)
189-
raise ValueError(msg)
190-
if len(input_indices) != len(args):
191-
msg = f"got {len(input_indices)} input_indices and {len(args)} args"
192-
raise ValueError(msg)
193-
if len(output_indices) != len(shapes):
194-
msg = f"got {len(output_indices)} input_indices and {len(shapes)} shapes"
195-
raise NotImplementedError(msg)
196-
if isinstance(adjust_chunks, Mapping):
197-
msg = "adjust_chunks must be a sequence of mappings"
198-
raise ValueError(msg)
199-
if adjust_chunks is not None and len(adjust_chunks) != len(shapes):
200-
msg = f"got {len(adjust_chunks)} adjust_chunks and {len(shapes)} shapes"
201-
raise ValueError(msg)
202-
if isinstance(new_axes, Mapping):
203-
msg = "new_axes must be a sequence of mappings"
204-
raise ValueError(msg)
205-
if new_axes is not None and len(new_axes) != len(shapes):
206-
msg = f"got {len(new_axes)} new_axes and {len(shapes)} shapes"
207-
raise ValueError(msg)
126+
import dask # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
127+
128+
metas = [arg._meta for arg in args if hasattr(arg, "_meta")] # pylint: disable=protected-access
129+
meta_xp = array_namespace(*metas)
130+
meta = metas[0]
208131

209-
# core_indices validation
210-
for inp_idx, arg in zip(input_indices, args, strict=True):
211-
for i, chunks in zip(inp_idx, arg.chunks, strict=True):
212-
if i in core_indices and len(chunks) > 1:
213-
msg = f"Core index {i} is broken into multiple chunks"
214-
raise ValueError(msg)
215-
216-
meta_xp = array_namespace(*(getattr(arg, "meta", None) for arg in args))
217-
wrapped = _npfunc_single_output_wrapper(func, meta_xp)
218-
dask_args = []
219-
for arg, inp_idx in zip(args, input_indices, strict=True):
220-
dask_args += [arg, inp_idx]
221-
222-
out = xp.blockwise(
223-
wrapped,
224-
output_indices[0],
225-
*dask_args,
226-
dtype=dtypes[0],
227-
adjust_chunks=adjust_chunks[0] if adjust_chunks is not None else None,
228-
new_axes=new_axes[0] if new_axes is not None else None,
229-
**kwargs,
132+
wrapped = dask.delayed(_npfunc_wrapper(func, meta_xp), pure=True)
133+
# This finalizes each arg, which is the same as arg.rechunk(-1)
134+
# Please read docstring above for why we're not using
135+
# dask.array.map_blocks or dask.array.blockwise!
136+
delayed_out = wrapped(*args, **kwargs)
137+
138+
return tuple(
139+
xp.from_delayed(delayed_out[i], shape=shape, dtype=dtype, meta=meta)
140+
for i, (shape, dtype) in enumerate(zip(shapes, dtypes, strict=True))
230141
)
231-
if out.shape != shapes[0]:
232-
msg = f"expected shape {shapes[0]}, but got {out.shape} from indices"
233-
raise ValueError(msg)
234-
return (out,)
235142

236-
wrapped = _npfunc_tuple_output_wrapper(func, xp)
143+
wrapped = _npfunc_wrapper(func, xp)
237144
if is_jax_namespace(xp):
238145
# If we're inside jax.jit, we can't eagerly convert
239146
# the JAX tracer objects to numpy.
240147
# Instead, we delay calling wrapped, which will receive
241148
# as arguments and will return JAX eager arrays.
149+
242150
import jax # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
243151

244152
return cast(
@@ -271,17 +179,17 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
271179
return out # type: ignore[no-any-return]
272180

273181

274-
def _npfunc_tuple_output_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
182+
def _npfunc_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
275183
func: Callable[..., NumPyObject | Sequence[NumPyObject]],
276184
xp: ModuleType,
277185
) -> Callable[..., tuple[Array, ...]]:
278186
"""
279187
Helper of `apply_numpy_func`.
280188
281189
Given a function that accepts one or more numpy arrays as positional arguments and
282-
returns a single numpy array or a sequence of numpy arrays,
283-
return a function that accepts the same number of Array API arrays and always
284-
returns a tuple of Array API array.
190+
returns a single numpy array or a sequence of numpy arrays, return a function that
191+
accepts the same number of Array API arrays and always returns a tuple of Array API
192+
array.
285193
286194
Any keyword arguments are passed through verbatim to the wrapped function.
287195
@@ -290,6 +198,7 @@ def _npfunc_tuple_output_wrapper( # type: ignore[no-any-explicit] # numpydoc i
290198
densification for sparse arrays, device->host transfer for cupy and torch arrays).
291199
"""
292200

201+
# On Dask, @wraps causes the graph key to contain the wrapped function's name
293202
@wraps(func)
294203
def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
295204
*args: Array, **kwargs: Any
@@ -311,41 +220,3 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
311220
return tuple(xp.asarray(o) for o in out)
312221

313222
return wrapper
314-
315-
316-
def _npfunc_single_output_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
317-
func: Callable[..., NumPyObject | Sequence[NumPyObject]],
318-
xp: ModuleType,
319-
) -> Callable[..., Array]:
320-
"""
321-
Dask-specific helper of `apply_numpy_func`.
322-
323-
Variant of `_npfunc_tuple_output_wrapper`, to be used with Dask which, at the time
324-
of writing, does not support multiple outputs in `dask.array.blockwise`.
325-
326-
func may return a single numpy object or a sequence with exactly one numpy object.
327-
The wrapper returns a single Array object, with no tuple wrapping.
328-
"""
329-
330-
# @wraps causes the generated dask key to contain the name of the wrapped function
331-
@wraps(func)
332-
def wrapper( # type: ignore[no-any-decorated,no-any-explicit] # numpydoc ignore=GL08
333-
*args: Array, **kwargs: Any
334-
) -> Array:
335-
import numpy as np # pylint: disable=import-outside-toplevel
336-
337-
args = tuple(np.asarray(arg) for arg in args)
338-
out = func(*args, **kwargs)
339-
340-
if not isinstance(out, np.ndarray | np.generic):
341-
if not isinstance(out, Sequence) or len(out) != 1: # pyright: ignore[reportUnnecessaryIsInstance]
342-
msg = (
343-
"apply_numpy_func: func must return a single numpy object or a "
344-
f"sequence with exactly one numpy object; got {out}"
345-
)
346-
raise ValueError(msg)
347-
out = out[0]
348-
349-
return xp.asarray(out)
350-
351-
return wrapper

0 commit comments

Comments
 (0)