Skip to content

Commit fa47560

Browse files
committed
overload
1 parent 0405ff5 commit fa47560

File tree

1 file changed

+91
-47
lines changed

1 file changed

+91
-47
lines changed

src/array_api_extra/_apply.py

Lines changed: 91 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections.abc import Callable, Sequence
77
from functools import wraps
88
from types import ModuleType
9-
from typing import TYPE_CHECKING, Any, cast
9+
from typing import TYPE_CHECKING, Any, cast, overload
1010

1111
from ._lib._compat import (
1212
array_namespace,
@@ -22,16 +22,39 @@
2222
import numpy as np
2323

2424
NumPyObject: TypeAlias = np.ndarray[Any, Any] | np.generic # type: ignore[no-any-explicit]
25+
KwArg: TypeAlias = Any # type: ignore[no-any-explicit]
26+
27+
28+
@overload
29+
def apply_numpy_func(
30+
func: Callable[..., NumPyObject],
31+
*args: Array,
32+
shape: tuple[int, ...] | None = None,
33+
dtype: DType | None = None,
34+
xp: ModuleType | None = None,
35+
**kwargs: KwArg,
36+
) -> Array: ... # numpydoc ignore=GL08
37+
38+
39+
@overload
40+
def apply_numpy_func( # type: ignore[no-any-decorated]
41+
func: Callable[..., Sequence[NumPyObject]],
42+
*args: Array,
43+
shape: Sequence[tuple[int, ...]],
44+
dtype: Sequence[DType] | None = None,
45+
xp: ModuleType | None = None,
46+
**kwargs: Any,
47+
) -> tuple[Array, ...]: ... # numpydoc ignore=GL08
2548

2649

2750
def apply_numpy_func( # type: ignore[no-any-explicit]
2851
func: Callable[..., NumPyObject | Sequence[NumPyObject]],
2952
*args: Array,
30-
shapes: Sequence[tuple[int, ...]] | None = None,
31-
dtypes: Sequence[DType] | None = None,
53+
shape: tuple[int, ...] | Sequence[tuple[int, ...]] | None = None,
54+
dtype: DType | Sequence[DType] | None = None,
3255
xp: ModuleType | None = None,
3356
**kwargs: Any,
34-
) -> tuple[Array, ...]:
57+
) -> Array | tuple[Array, ...]:
3558
"""
3659
Apply a function that operates on NumPy arrays to Array API compliant arrays.
3760
@@ -48,15 +71,11 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
4871
One or more Array API compliant arrays. You need to be able to apply
4972
``np.asarray()`` to them to convert them to numpy; read notes below about
5073
specific backends.
51-
shapes : Sequence[tuple[int, ...]], optional
52-
Sequence of output shapes, one for each output of `func`.
53-
If `func` returns a single (non-sequence) output, this must be a sequence
54-
with a single element.
55-
Default: assume a single output and broadcast shapes of the input arrays.
56-
dtypes : Sequence[DType], optional
57-
Sequence of output dtypes, one for each output of `func`.
58-
If `func` returns a single (non-sequence) output, this must be a sequence
59-
with a single element.
74+
shape : tuple[int, ...] | Sequence[tuple[int, ...]], optional
75+
Output shape or sequence of output shapes, one for each output of `func`.
76+
Default: assume single output and broadcast shapes of the input arrays.
77+
dtype : DType | Sequence[DType], optional
78+
Output dtype or sequence of output dtypes, one for each output of `func`.
6079
Default: infer the result type(s) from the input arrays.
6180
xp : array_namespace, optional
6281
The standard-compatible namespace for `args`. Default: infer.
@@ -66,9 +85,11 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
6685
6786
Returns
6887
-------
69-
tuple[Array, ...]
70-
The result(s) of `func` applied to the input arrays.
71-
This is always a tuple, even if `func` returns a single output.
88+
Array | tuple[Array, ...]
89+
The result(s) of `func` applied to the input arrays, wrapped in the same
90+
array namespace as the inputs.
91+
If shape is omitted or a `tuple[int, ...]`, this is a single array.
92+
Otherwise, it's a tuple of arrays.
7293
7394
Notes
7495
-----
@@ -110,46 +131,67 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
110131
"""
111132
if xp is None:
112133
xp = array_namespace(*args)
113-
if shapes is None:
134+
135+
# Normalize and validate shape and dtype
136+
multi_output = False
137+
if shape is None:
114138
shapes = [xp.broadcast_shapes(*(arg.shape for arg in args))]
115-
if dtypes is None:
139+
elif isinstance(shape, tuple) and all(isinstance(s, int) for s in shape):
140+
shapes = [shape]
141+
else:
142+
shapes = shape
143+
multi_output = True
144+
145+
if dtype is None:
116146
dtypes = [xp.result_type(*args)] * len(shapes)
147+
elif multi_output:
148+
if not isinstance(dtype, Sequence):
149+
msg = "Got sequence of shapes but only one dtype"
150+
raise TypeError(msg)
151+
dtypes = dtype
152+
else:
153+
if isinstance(dtype, Sequence):
154+
msg = "Got single shape but multiple dtypes"
155+
raise TypeError(msg)
156+
dtypes = [dtype]
117157

118158
if len(shapes) != len(dtypes):
119-
msg = f"got {len(shapes)} shapes and {len(dtypes)} dtypes"
159+
msg = f"Got {len(shapes)} shapes and {len(dtypes)} dtypes"
120160
raise ValueError(msg)
121161
if len(shapes) == 0:
122-
msg = "Must have at least one output array"
162+
msg = "func must return one or more output arrays"
123163
raise ValueError(msg)
164+
del shape
165+
del dtype
124166

167+
# Backend-specific branches
125168
if is_dask_namespace(xp):
126169
import dask # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
127170

128171
metas = [arg._meta for arg in args if hasattr(arg, "_meta")] # pylint: disable=protected-access
129172
meta_xp = array_namespace(*metas)
130-
meta = metas[0]
131173

132-
wrapped = dask.delayed(_npfunc_wrapper(func, meta_xp), pure=True)
174+
wrapped = dask.delayed(_npfunc_wrapper(func, multi_output, meta_xp), pure=True)
133175
# This finalizes each arg, which is the same as arg.rechunk(-1)
134176
# Please read docstring above for why we're not using
135177
# dask.array.map_blocks or dask.array.blockwise!
136178
delayed_out = wrapped(*args, **kwargs)
137179

138-
return tuple(
139-
xp.from_delayed(delayed_out[i], shape=shape, dtype=dtype, meta=meta)
180+
out = tuple(
181+
xp.from_delayed(delayed_out[i], shape=shape, dtype=dtype, meta=metas[0])
140182
for i, (shape, dtype) in enumerate(zip(shapes, dtypes, strict=True))
141183
)
142184

143-
wrapped = _npfunc_wrapper(func, xp)
144-
if is_jax_namespace(xp):
185+
elif is_jax_namespace(xp):
145186
# If we're inside jax.jit, we can't eagerly convert
146187
# the JAX tracer objects to numpy.
147188
# Instead, we delay calling wrapped, which will receive
148189
# as arguments and will return JAX eager arrays.
149190

150191
import jax # type: ignore[import-not-found] # pylint: disable=import-outside-toplevel,import-error # pyright: ignore[reportMissingImports]
151192

152-
return cast(
193+
wrapped = _npfunc_wrapper(func, multi_output, xp)
194+
out = cast(
153195
tuple[Array, ...],
154196
jax.pure_callback(
155197
wrapped,
@@ -162,25 +204,29 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
162204
),
163205
)
164206

165-
# Eager backends
166-
out = wrapped(*args, **kwargs)
207+
else:
208+
# Eager backends
209+
wrapped = _npfunc_wrapper(func, multi_output, xp)
210+
out = wrapped(*args, **kwargs)
167211

168-
# Output validation
169-
if len(out) != len(shapes):
170-
msg = f"func was declared to return {len(shapes)} outputs, got {len(out)}"
171-
raise ValueError(msg)
172-
for out_i, shape_i, dtype_i in zip(out, shapes, dtypes, strict=True):
173-
if out_i.shape != shape_i:
174-
msg = f"expected shape {shape_i}, got {out_i.shape}"
175-
raise ValueError(msg)
176-
if not xp.isdtype(out_i.dtype, dtype_i):
177-
msg = f"expected dtype {dtype_i}, got {out_i.dtype}"
212+
# Output validation
213+
if len(out) != len(shapes):
214+
msg = f"func was declared to return {len(shapes)} outputs, got {len(out)}"
178215
raise ValueError(msg)
179-
return out # type: ignore[no-any-return]
216+
for out_i, shape_i, dtype_i in zip(out, shapes, dtypes, strict=True):
217+
if out_i.shape != shape_i:
218+
msg = f"expected shape {shape_i}, got {out_i.shape}"
219+
raise ValueError(msg)
220+
if not xp.isdtype(out_i.dtype, dtype_i):
221+
msg = f"expected dtype {dtype_i}, got {out_i.dtype}"
222+
raise ValueError(msg)
223+
224+
return out if multi_output else out[0]
180225

181226

182227
def _npfunc_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01
183228
func: Callable[..., NumPyObject | Sequence[NumPyObject]],
229+
multi_output: bool,
184230
xp: ModuleType,
185231
) -> Callable[..., tuple[Array, ...]]:
186232
"""
@@ -208,14 +254,12 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
208254
args = tuple(np.asarray(arg) for arg in args)
209255
out = func(*args, **kwargs)
210256

211-
if isinstance(out, np.ndarray | np.generic):
257+
if multi_output:
258+
if not isinstance(out, Sequence) or isinstance(out, np.ndarray):
259+
msg = "Expected multiple outputs, got a single one"
260+
raise ValueError(msg)
261+
else:
212262
out = (out,)
213-
elif not isinstance(out, Sequence): # pyright: ignore[reportUnnecessaryIsInstance]
214-
msg = (
215-
"apply_numpy_func: func must return a numpy object or a "
216-
f"sequence of numpy objects; got {out}"
217-
)
218-
raise TypeError(msg)
219263

220264
return tuple(xp.asarray(o) for o in out)
221265

0 commit comments

Comments
 (0)