6
6
from collections .abc import Callable , Sequence
7
7
from functools import wraps
8
8
from types import ModuleType
9
- from typing import TYPE_CHECKING , Any , cast
9
+ from typing import TYPE_CHECKING , Any , cast , overload
10
10
11
11
from ._lib ._compat import (
12
12
array_namespace ,
22
22
import numpy as np
23
23
24
24
NumPyObject : TypeAlias = np .ndarray [Any , Any ] | np .generic # type: ignore[no-any-explicit]
25
+ KwArg : TypeAlias = Any # type: ignore[no-any-explicit]
26
+
27
+
28
+ @overload
29
+ def apply_numpy_func (
30
+ func : Callable [..., NumPyObject ],
31
+ * args : Array ,
32
+ shape : tuple [int , ...] | None = None ,
33
+ dtype : DType | None = None ,
34
+ xp : ModuleType | None = None ,
35
+ ** kwargs : KwArg ,
36
+ ) -> Array : ... # numpydoc ignore=GL08
37
+
38
+
39
+ @overload
40
+ def apply_numpy_func ( # type: ignore[no-any-decorated]
41
+ func : Callable [..., Sequence [NumPyObject ]],
42
+ * args : Array ,
43
+ shape : Sequence [tuple [int , ...]],
44
+ dtype : Sequence [DType ] | None = None ,
45
+ xp : ModuleType | None = None ,
46
+ ** kwargs : Any ,
47
+ ) -> tuple [Array , ...]: ... # numpydoc ignore=GL08
25
48
26
49
27
50
def apply_numpy_func ( # type: ignore[no-any-explicit]
28
51
func : Callable [..., NumPyObject | Sequence [NumPyObject ]],
29
52
* args : Array ,
30
- shapes : Sequence [tuple [int , ...]] | None = None ,
31
- dtypes : Sequence [DType ] | None = None ,
53
+ shape : tuple [ int , ...] | Sequence [tuple [int , ...]] | None = None ,
54
+ dtype : DType | Sequence [DType ] | None = None ,
32
55
xp : ModuleType | None = None ,
33
56
** kwargs : Any ,
34
- ) -> tuple [Array , ...]:
57
+ ) -> Array | tuple [Array , ...]:
35
58
"""
36
59
Apply a function that operates on NumPy arrays to Array API compliant arrays.
37
60
@@ -48,15 +71,11 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
48
71
One or more Array API compliant arrays. You need to be able to apply
49
72
``np.asarray()`` to them to convert them to numpy; read notes below about
50
73
specific backends.
51
- shapes : Sequence[tuple[int, ...]], optional
52
- Sequence of output shapes, one for each output of `func`.
53
- If `func` returns a single (non-sequence) output, this must be a sequence
54
- with a single element.
55
- Default: assume a single output and broadcast shapes of the input arrays.
56
- dtypes : Sequence[DType], optional
57
- Sequence of output dtypes, one for each output of `func`.
58
- If `func` returns a single (non-sequence) output, this must be a sequence
59
- with a single element.
74
+ shape : tuple[int, ...] | Sequence[tuple[int, ...]], optional
75
+ Output shape or sequence of output shapes, one for each output of `func`.
76
+ Default: assume single output and broadcast shapes of the input arrays.
77
+ dtype : DType | Sequence[DType], optional
78
+ Output dtype or sequence of output dtypes, one for each output of `func`.
60
79
Default: infer the result type(s) from the input arrays.
61
80
xp : array_namespace, optional
62
81
The standard-compatible namespace for `args`. Default: infer.
@@ -66,9 +85,11 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
66
85
67
86
Returns
68
87
-------
69
- tuple[Array, ...]
70
- The result(s) of `func` applied to the input arrays.
71
- This is always a tuple, even if `func` returns a single output.
88
+ Array | tuple[Array, ...]
89
+ The result(s) of `func` applied to the input arrays, wrapped in the same
90
+ array namespace as the inputs.
91
+ If shape is omitted or a `tuple[int, ...]`, this is a single array.
92
+ Otherwise, it's a tuple of arrays.
72
93
73
94
Notes
74
95
-----
@@ -110,46 +131,67 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
110
131
"""
111
132
if xp is None :
112
133
xp = array_namespace (* args )
113
- if shapes is None :
134
+
135
+ # Normalize and validate shape and dtype
136
+ multi_output = False
137
+ if shape is None :
114
138
shapes = [xp .broadcast_shapes (* (arg .shape for arg in args ))]
115
- if dtypes is None :
139
+ elif isinstance (shape , tuple ) and all (isinstance (s , int ) for s in shape ):
140
+ shapes = [shape ]
141
+ else :
142
+ shapes = shape
143
+ multi_output = True
144
+
145
+ if dtype is None :
116
146
dtypes = [xp .result_type (* args )] * len (shapes )
147
+ elif multi_output :
148
+ if not isinstance (dtype , Sequence ):
149
+ msg = "Got sequence of shapes but only one dtype"
150
+ raise TypeError (msg )
151
+ dtypes = dtype
152
+ else :
153
+ if isinstance (dtype , Sequence ):
154
+ msg = "Got single shape but multiple dtypes"
155
+ raise TypeError (msg )
156
+ dtypes = [dtype ]
117
157
118
158
if len (shapes ) != len (dtypes ):
119
- msg = f"got { len (shapes )} shapes and { len (dtypes )} dtypes"
159
+ msg = f"Got { len (shapes )} shapes and { len (dtypes )} dtypes"
120
160
raise ValueError (msg )
121
161
if len (shapes ) == 0 :
122
- msg = "Must have at least one output array "
162
+ msg = "func must return one or more output arrays "
123
163
raise ValueError (msg )
164
+ del shape
165
+ del dtype
124
166
167
+ # Backend-specific branches
125
168
if is_dask_namespace (xp ):
126
169
import dask # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
127
170
128
171
metas = [arg ._meta for arg in args if hasattr (arg , "_meta" )] # pylint: disable=protected-access
129
172
meta_xp = array_namespace (* metas )
130
- meta = metas [0 ]
131
173
132
- wrapped = dask .delayed (_npfunc_wrapper (func , meta_xp ), pure = True )
174
+ wrapped = dask .delayed (_npfunc_wrapper (func , multi_output , meta_xp ), pure = True )
133
175
# This finalizes each arg, which is the same as arg.rechunk(-1)
134
176
# Please read docstring above for why we're not using
135
177
# dask.array.map_blocks or dask.array.blockwise!
136
178
delayed_out = wrapped (* args , ** kwargs )
137
179
138
- return tuple (
139
- xp .from_delayed (delayed_out [i ], shape = shape , dtype = dtype , meta = meta )
180
+ out = tuple (
181
+ xp .from_delayed (delayed_out [i ], shape = shape , dtype = dtype , meta = metas [ 0 ] )
140
182
for i , (shape , dtype ) in enumerate (zip (shapes , dtypes , strict = True ))
141
183
)
142
184
143
- wrapped = _npfunc_wrapper (func , xp )
144
- if is_jax_namespace (xp ):
185
+ elif is_jax_namespace (xp ):
145
186
# If we're inside jax.jit, we can't eagerly convert
146
187
# the JAX tracer objects to numpy.
147
188
# Instead, we delay calling wrapped, which will receive
148
189
# as arguments and will return JAX eager arrays.
149
190
150
191
import jax # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
151
192
152
- return cast (
193
+ wrapped = _npfunc_wrapper (func , multi_output , xp )
194
+ out = cast (
153
195
tuple [Array , ...],
154
196
jax .pure_callback (
155
197
wrapped ,
@@ -162,25 +204,29 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
162
204
),
163
205
)
164
206
165
- # Eager backends
166
- out = wrapped (* args , ** kwargs )
207
+ else :
208
+ # Eager backends
209
+ wrapped = _npfunc_wrapper (func , multi_output , xp )
210
+ out = wrapped (* args , ** kwargs )
167
211
168
- # Output validation
169
- if len (out ) != len (shapes ):
170
- msg = f"func was declared to return { len (shapes )} outputs, got { len (out )} "
171
- raise ValueError (msg )
172
- for out_i , shape_i , dtype_i in zip (out , shapes , dtypes , strict = True ):
173
- if out_i .shape != shape_i :
174
- msg = f"expected shape { shape_i } , got { out_i .shape } "
175
- raise ValueError (msg )
176
- if not xp .isdtype (out_i .dtype , dtype_i ):
177
- msg = f"expected dtype { dtype_i } , got { out_i .dtype } "
212
+ # Output validation
213
+ if len (out ) != len (shapes ):
214
+ msg = f"func was declared to return { len (shapes )} outputs, got { len (out )} "
178
215
raise ValueError (msg )
179
- return out # type: ignore[no-any-return]
216
+ for out_i , shape_i , dtype_i in zip (out , shapes , dtypes , strict = True ):
217
+ if out_i .shape != shape_i :
218
+ msg = f"expected shape { shape_i } , got { out_i .shape } "
219
+ raise ValueError (msg )
220
+ if not xp .isdtype (out_i .dtype , dtype_i ):
221
+ msg = f"expected dtype { dtype_i } , got { out_i .dtype } "
222
+ raise ValueError (msg )
223
+
224
+ return out if multi_output else out [0 ]
180
225
181
226
182
227
def _npfunc_wrapper ( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
183
228
func : Callable [..., NumPyObject | Sequence [NumPyObject ]],
229
+ multi_output : bool ,
184
230
xp : ModuleType ,
185
231
) -> Callable [..., tuple [Array , ...]]:
186
232
"""
@@ -208,14 +254,12 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
208
254
args = tuple (np .asarray (arg ) for arg in args )
209
255
out = func (* args , ** kwargs )
210
256
211
- if isinstance (out , np .ndarray | np .generic ):
257
+ if multi_output :
258
+ if not isinstance (out , Sequence ) or isinstance (out , np .ndarray ):
259
+ msg = "Expected multiple outputs, got a single one"
260
+ raise ValueError (msg )
261
+ else :
212
262
out = (out ,)
213
- elif not isinstance (out , Sequence ): # pyright: ignore[reportUnnecessaryIsInstance]
214
- msg = (
215
- "apply_numpy_func: func must return a numpy object or a "
216
- f"sequence of numpy objects; got { out } "
217
- )
218
- raise TypeError (msg )
219
263
220
264
return tuple (xp .asarray (o ) for o in out )
221
265
0 commit comments