3
3
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
4
4
from __future__ import annotations
5
5
6
- from collections .abc import Callable , Hashable , Mapping , Sequence
6
+ from collections .abc import Callable , Sequence
7
7
from functools import wraps
8
8
from types import ModuleType
9
9
from typing import TYPE_CHECKING , Any , cast
20
20
from typing import TypeAlias
21
21
22
22
import numpy as np
23
- import numpy .typing as npt
24
23
25
- NumPyObject : TypeAlias = npt . NDArray [ DType ] | np .generic # type: ignore[no-any-explicit]
24
+ NumPyObject : TypeAlias = np . ndarray [ Any , Any ] | np .generic # type: ignore[no-any-explicit]
26
25
27
26
28
27
def apply_numpy_func ( # type: ignore[no-any-explicit]
@@ -31,11 +30,6 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
31
30
shapes : Sequence [tuple [int , ...]] | None = None ,
32
31
dtypes : Sequence [DType ] | None = None ,
33
32
xp : ModuleType | None = None ,
34
- input_indices : Sequence [Sequence [Hashable ]] | None = None ,
35
- core_indices : Sequence [Hashable ] | None = None ,
36
- output_indices : Sequence [Sequence [Hashable ]] | None = None ,
37
- adjust_chunks : Sequence [dict [Hashable , Callable [[int ], int ]]] | None = None ,
38
- new_axes : Sequence [dict [Hashable , int ]] | None = None ,
39
33
** kwargs : Any ,
40
34
) -> tuple [Array , ...]:
41
35
"""
@@ -66,33 +60,6 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
66
60
Default: infer the result type(s) from the input arrays.
67
61
xp : array_namespace, optional
68
62
The standard-compatible namespace for `args`. Default: infer.
69
- input_indices : Sequence[Sequence[Hashable]], optional
70
- Dask specific.
71
- Axes labels for each input array, e.g. if there are two args with respectively
72
- ndim=3 and 1, `input_indices` could be ``['ijk', 'j']`` or ``[(0, 1, 2),
73
- (1,)]``.
74
- Default: disallow Dask.
75
- core_indices : Sequence[Hashable], optional
76
- **Dask specific.**
77
- Axes of the input arrays that cannot be broken into chunks.
78
- Default: disallow Dask.
79
- output_indices : Sequence[Sequence[Hashable]], optional
80
- **Dask specific.**
81
- Axes labels for each output array. If `func` returns a single (non-sequence)
82
- output, this must be a sequence containing a single sequence of labels, e.g.
83
- ``['ijk']``.
84
- Default: disallow Dask.
85
- adjust_chunks : Sequence[Mapping[Hashable, Callable[[int], int]]], optional
86
- **Dask specific.**
87
- Sequence of dicts, one per output, mapping index to function to be applied to
88
- each chunk to determine the output size. The total must add up to the output
89
- shape.
90
- Default: on Dask, the size along each index cannot change.
91
- new_axes : Sequence[Mapping[Hashable, int]], optional
92
- **Dask specific.**
93
- New indexes and their dimension lengths, one per output.
94
- Default: on Dask, there can't be `output_indices` that don't appear in
95
- `input_indices`.
96
63
**kwargs : Any, optional
97
64
Additional keyword arguments to pass verbatim to `func`.
98
65
Any array objects in them won't be converted to NumPy.
@@ -124,43 +91,22 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
124
91
<https://sparse.pydata.org/en/stable/operations.html#package-configuration>`_.
125
92
126
93
Dask
127
- This allows applying eager functions to the individual chunks of dask arrays.
128
- The dask graph won't be computed. As a special limitation, `func` must return
129
- exactly one output.
94
+ This allows applying eager functions to dask arrays.
95
+ The dask graph won't be computed.
130
96
131
- In order to enable running on Dask you need to specify at least
132
- `input_indices`, `output_indices`, and `core_indices`, but you may also need
133
- `adjust_chunks` and `new_axes` depending on the function .
97
+ `apply_numpy_func` doesn't know if `func` reduces along any axes and shape
98
+ changes are non-trivial in chunked Dask arrays. For these reasons, all inputs
99
+ will be rechunked into a single chunk.
134
100
135
- Read `dask.array.blockwise`:
136
- - ``input_indices`` map to the even ``*args`` of `dask.array.blockwise`
137
- - ``output_indices[0]`` maps to the ``out_ind`` parameter
138
- - ``adjust_chunks[0]`` maps to the ``adjust_chunks`` parameter
139
- - ``new_axes[0]`` maps to the ``new_axes`` parameter
101
+ .. warning::
140
102
141
- ``core_indices`` is a safety measure to prevent incorrect results on
142
- Dask along chunked axes. Consider this::
103
+ The whole operation needs to fit in memory all at once on a single worker.
143
104
144
- >>> apply_numpy_func(lambda x: x + x.sum(axis=0), x,
145
- ... input_indices=['ij'], output_indices=['ij'])
146
-
147
- The above example would produce incorrect results if x is a dask array with more
148
- than one chunk along axis 0, as each chunk will calculate its own local
149
- subtotal. To prevent this, we need to declare the first axis of ``args[0]`` as a
150
- *core axis*::
151
-
152
- >>> apply_numpy_func(lambda x: x + x.sum(axis=0), x,
153
- ... input_indices=['ij'], output_indices=['ij'],
154
- ... core_indices='i')
155
-
156
- This will cause `apply_numpy_func` to raise if the first axis of `x` is broken
157
- along multiple chunks, thus forcing the final user to rechunk ahead of time:
158
-
159
- >>> x = x.chunk({0: -1})
160
-
161
- This needs to always be a conscious decision on behalf of the final user, as the
162
- new chunks will be larger than the old and may cause memory issues, unless chunk
163
- size is reduced along a different, non-core axis.
105
+ The outputs will also be returned as a single chunk and you should consider
106
+ rechunking them into smaller chunks afterwards.
107
+ If you want to distribute the calculation across multiple workers, you
108
+ should use `dask.array.map_blocks`, `dask.array.blockwise`,
109
+ `dask.array.map_overlap`, or a native Dask wrapper instead of this function.
164
110
"""
165
111
if xp is None :
166
112
xp = array_namespace (* args )
@@ -177,68 +123,30 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
177
123
raise ValueError (msg )
178
124
179
125
if is_dask_namespace (xp ):
180
- # General validation
181
- if len (shapes ) > 1 :
182
- msg = "dask.array.map_blocks() does not support multiple outputs"
183
- raise NotImplementedError (msg )
184
- if input_indices is None or output_indices is None or core_indices is None :
185
- msg = (
186
- "Dask is disallowed unless one declares input_indices, "
187
- "output_indices, and core_indices"
188
- )
189
- raise ValueError (msg )
190
- if len (input_indices ) != len (args ):
191
- msg = f"got { len (input_indices )} input_indices and { len (args )} args"
192
- raise ValueError (msg )
193
- if len (output_indices ) != len (shapes ):
194
- msg = f"got { len (output_indices )} input_indices and { len (shapes )} shapes"
195
- raise NotImplementedError (msg )
196
- if isinstance (adjust_chunks , Mapping ):
197
- msg = "adjust_chunks must be a sequence of mappings"
198
- raise ValueError (msg )
199
- if adjust_chunks is not None and len (adjust_chunks ) != len (shapes ):
200
- msg = f"got { len (adjust_chunks )} adjust_chunks and { len (shapes )} shapes"
201
- raise ValueError (msg )
202
- if isinstance (new_axes , Mapping ):
203
- msg = "new_axes must be a sequence of mappings"
204
- raise ValueError (msg )
205
- if new_axes is not None and len (new_axes ) != len (shapes ):
206
- msg = f"got { len (new_axes )} new_axes and { len (shapes )} shapes"
207
- raise ValueError (msg )
126
+ import dask # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
127
+
128
+ metas = [arg ._meta for arg in args if hasattr (arg , "_meta" )] # pylint: disable=protected-access
129
+ meta_xp = array_namespace (* metas )
130
+ meta = metas [0 ]
208
131
209
- # core_indices validation
210
- for inp_idx , arg in zip (input_indices , args , strict = True ):
211
- for i , chunks in zip (inp_idx , arg .chunks , strict = True ):
212
- if i in core_indices and len (chunks ) > 1 :
213
- msg = f"Core index { i } is broken into multiple chunks"
214
- raise ValueError (msg )
215
-
216
- meta_xp = array_namespace (* (getattr (arg , "meta" , None ) for arg in args ))
217
- wrapped = _npfunc_single_output_wrapper (func , meta_xp )
218
- dask_args = []
219
- for arg , inp_idx in zip (args , input_indices , strict = True ):
220
- dask_args += [arg , inp_idx ]
221
-
222
- out = xp .blockwise (
223
- wrapped ,
224
- output_indices [0 ],
225
- * dask_args ,
226
- dtype = dtypes [0 ],
227
- adjust_chunks = adjust_chunks [0 ] if adjust_chunks is not None else None ,
228
- new_axes = new_axes [0 ] if new_axes is not None else None ,
229
- ** kwargs ,
132
+ wrapped = dask .delayed (_npfunc_wrapper (func , meta_xp ), pure = True )
133
+ # This finalizes each arg, which is the same as arg.rechunk(-1)
134
+ # Please read docstring above for why we're not using
135
+ # dask.array.map_blocks or dask.array.blockwise!
136
+ delayed_out = wrapped (* args , ** kwargs )
137
+
138
+ return tuple (
139
+ xp .from_delayed (delayed_out [i ], shape = shape , dtype = dtype , meta = meta )
140
+ for i , (shape , dtype ) in enumerate (zip (shapes , dtypes , strict = True ))
230
141
)
231
- if out .shape != shapes [0 ]:
232
- msg = f"expected shape { shapes [0 ]} , but got { out .shape } from indices"
233
- raise ValueError (msg )
234
- return (out ,)
235
142
236
- wrapped = _npfunc_tuple_output_wrapper (func , xp )
143
+ wrapped = _npfunc_wrapper (func , xp )
237
144
if is_jax_namespace (xp ):
238
145
# If we're inside jax.jit, we can't eagerly convert
239
146
# the JAX tracer objects to numpy.
240
147
# Instead, we delay calling wrapped, which will receive
241
148
# as arguments and will return JAX eager arrays.
149
+
242
150
import jax # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
243
151
244
152
return cast (
@@ -271,17 +179,17 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
271
179
return out # type: ignore[no-any-return]
272
180
273
181
274
- def _npfunc_tuple_output_wrapper ( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
182
+ def _npfunc_wrapper ( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
275
183
func : Callable [..., NumPyObject | Sequence [NumPyObject ]],
276
184
xp : ModuleType ,
277
185
) -> Callable [..., tuple [Array , ...]]:
278
186
"""
279
187
Helper of `apply_numpy_func`.
280
188
281
189
Given a function that accepts one or more numpy arrays as positional arguments and
282
- returns a single numpy array or a sequence of numpy arrays,
283
- return a function that accepts the same number of Array API arrays and always
284
- returns a tuple of Array API array.
190
+ returns a single numpy array or a sequence of numpy arrays, return a function that
191
+ accepts the same number of Array API arrays and always returns a tuple of Array API
192
+ arrays.
285
193
286
194
Any keyword arguments are passed through verbatim to the wrapped function.
287
195
@@ -290,6 +198,7 @@ def _npfunc_tuple_output_wrapper( # type: ignore[no-any-explicit] # numpydoc i
290
198
densification for sparse arrays, device->host transfer for cupy and torch arrays).
291
199
"""
292
200
201
+ # On Dask, @wraps causes the graph key to contain the wrapped function's name
293
202
@wraps (func )
294
203
def wrapper ( # type: ignore[no-any-decorated,no-any-explicit]
295
204
* args : Array , ** kwargs : Any
@@ -311,41 +220,3 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
311
220
return tuple (xp .asarray (o ) for o in out )
312
221
313
222
return wrapper
314
-
315
-
316
- def _npfunc_single_output_wrapper ( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
317
- func : Callable [..., NumPyObject | Sequence [NumPyObject ]],
318
- xp : ModuleType ,
319
- ) -> Callable [..., Array ]:
320
- """
321
- Dask-specific helper of `apply_numpy_func`.
322
-
323
- Variant of `_npfunc_tuple_output_wrapper`, to be used with Dask which, at the time
324
- of writing, does not support multiple outputs in `dask.array.blockwise`.
325
-
326
- func may return a single numpy object or a sequence with exactly one numpy object.
327
- The wrapper returns a single Array object, with no tuple wrapping.
328
- """
329
-
330
- # @wraps causes the generated dask key to contain the name of the wrapped function
331
- @wraps (func )
332
- def wrapper ( # type: ignore[no-any-decorated,no-any-explicit] # numpydoc ignore=GL08
333
- * args : Array , ** kwargs : Any
334
- ) -> Array :
335
- import numpy as np # pylint: disable=import-outside-toplevel
336
-
337
- args = tuple (np .asarray (arg ) for arg in args )
338
- out = func (* args , ** kwargs )
339
-
340
- if not isinstance (out , np .ndarray | np .generic ):
341
- if not isinstance (out , Sequence ) or len (out ) != 1 : # pyright: ignore[reportUnnecessaryIsInstance]
342
- msg = (
343
- "apply_numpy_func: func must return a single numpy object or a "
344
- f"sequence with exactly one numpy object; got { out } "
345
- )
346
- raise ValueError (msg )
347
- out = out [0 ]
348
-
349
- return xp .asarray (out )
350
-
351
- return wrapper
0 commit comments