Skip to content

Commit bf67bb8

Browse files
committed
MAINT: Array API 2024.12 typing nits
1 parent 75e5166 commit bf67bb8

File tree

4 files changed

+52
-28
lines changed

4 files changed

+52
-28
lines changed

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ increase performance.
185185
In particular, the following kinds of function are also in-scope:
186186

187187
- Functions which implement
188-
[array API standard extension](https://data-apis.org/array-api/2023.12/extensions/index.html)
188+
[array API standard extension](https://data-apis.org/array-api/latest/extensions/index.html)
189189
functions in terms of functions from the base standard.
190190
- Functions which add functionality (e.g. extra parameters) to functions from
191191
the standard.

src/array_api_extra/_lib/_funcs.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import warnings
88
from collections.abc import Sequence
99
from types import ModuleType
10-
from typing import cast
10+
from typing import TYPE_CHECKING, cast
1111

1212
from ._at import at
1313
from ._utils import _compat, _helpers
@@ -375,8 +375,8 @@ def expand_dims(
375375

376376

377377
def isclose(
378-
a: Array,
379-
b: Array,
378+
a: Array | complex,
379+
b: Array | complex,
380380
*,
381381
rtol: float = 1e-05,
382382
atol: float = 1e-08,
@@ -385,6 +385,9 @@ def isclose(
385385
) -> Array: # numpydoc ignore=PR01,RT01
386386
"""See docstring in array_api_extra._delegation."""
387387
a, b = asarrays(a, b, xp=xp)
388+
if TYPE_CHECKING: # Hack around pyright bug # pragma: no cover
389+
assert _compat.is_array_api_obj(a)
390+
assert _compat.is_array_api_obj(b)
388391

389392
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
390393
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
@@ -419,7 +422,13 @@ def isclose(
419422
return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol)
420423

421424

422-
def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
425+
def kron(
426+
a: Array | complex,
427+
b: Array | complex,
428+
/,
429+
*,
430+
xp: ModuleType | None = None,
431+
) -> Array:
423432
"""
424433
Kronecker product of two arrays.
425434
@@ -495,9 +504,14 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
495504
if xp is None:
496505
xp = array_namespace(a, b)
497506
a, b = asarrays(a, b, xp=xp)
507+
if TYPE_CHECKING: # Hack around pyright bug # pragma: no cover
508+
assert _compat.is_array_api_obj(a)
509+
assert _compat.is_array_api_obj(b)
498510

499511
singletons = (1,) * (b.ndim - a.ndim)
500512
a = xp.broadcast_to(a, singletons + a.shape)
513+
if TYPE_CHECKING: # Hack around pyright bug # pragma: no cover
514+
assert _compat.is_array_api_obj(a)
501515

502516
nd_b, nd_a = b.ndim, a.ndim
503517
nd_max = max(nd_b, nd_a)
@@ -614,8 +628,8 @@ def pad(
614628

615629

616630
def setdiff1d(
617-
x1: Array,
618-
x2: Array,
631+
x1: Array | complex,
632+
x2: Array | complex,
619633
/,
620634
*,
621635
assume_unique: bool = False,
@@ -665,6 +679,10 @@ def setdiff1d(
665679
else:
666680
x1 = xp.unique_values(x1)
667681
x2 = xp.unique_values(x2)
682+
683+
if TYPE_CHECKING: # Hack around pyright bug # pragma: no cover
684+
assert _compat.is_array_api_obj(x1)
685+
668686
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
669687

670688

src/array_api_extra/_lib/_utils/_compat.pyi

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,30 @@ from __future__ import annotations
55

66
from types import ModuleType
77

8+
# TODO import from typing (requires Python >=3.13)
9+
from typing_extensions import TypeIs
10+
811
from ._typing import Array, Device
912

1013
# pylint: disable=missing-class-docstring,unused-argument
1114

12-
class ArrayModule(ModuleType):
15+
class Namespace(ModuleType):
1316
def device(self, x: Array, /) -> Device: ...
1417

1518
def array_namespace(
16-
*xs: Array,
19+
*xs: Array | complex | None,
1720
api_version: str | None = None,
1821
use_compat: bool | None = None,
19-
) -> ArrayModule: ...
22+
) -> Namespace: ...
2023
def device(x: Array, /) -> Device: ...
21-
def is_array_api_obj(x: object, /) -> bool: ...
22-
def is_array_api_strict_namespace(xp: ModuleType, /) -> bool: ...
23-
def is_cupy_namespace(xp: ModuleType, /) -> bool: ...
24-
def is_dask_namespace(xp: ModuleType, /) -> bool: ...
25-
def is_jax_namespace(xp: ModuleType, /) -> bool: ...
26-
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
27-
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
28-
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
24+
def is_array_api_obj(x: object, /) -> TypeIs[Array]: ...
25+
def is_array_api_strict_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
26+
def is_cupy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
27+
def is_dask_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
28+
def is_jax_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
29+
def is_numpy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
30+
def is_pydata_sparse_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
31+
def is_torch_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
2932
def is_cupy_array(x: object, /) -> bool: ...
3033
def is_dask_array(x: object, /) -> bool: ...
3134
def is_jax_array(x: object, /) -> bool: ...

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,17 @@
55

66
from collections.abc import Generator
77
from types import ModuleType
8-
from typing import cast
8+
from typing import TYPE_CHECKING
99

1010
from . import _compat
1111
from ._compat import array_namespace, is_array_api_obj, is_numpy_array
1212
from ._typing import Array
1313

14+
if TYPE_CHECKING: # pragma: no cover
15+
# TODO import from typing (requires Python >=3.13)
16+
from typing_extensions import TypeIs
17+
18+
1419
__all__ = ["asarrays", "in1d", "is_python_scalar", "mean"]
1520

1621

@@ -96,16 +101,16 @@ def mean(
96101
return xp.mean(x, axis=axis, keepdims=keepdims)
97102

98103

99-
def is_python_scalar(x: object) -> bool: # numpydoc ignore=PR01,RT01
104+
def is_python_scalar(x: object) -> TypeIs[complex]: # numpydoc ignore=PR01,RT01
100105
"""Return True if `x` is a Python scalar, False otherwise."""
101106
# isinstance(x, float) returns True for np.float64
102107
# isinstance(x, complex) returns True for np.complex128
103-
return isinstance(x, int | float | complex | bool) and not is_numpy_array(x)
108+
return isinstance(x, int | float | complex) and not is_numpy_array(x)
104109

105110

106111
def asarrays(
107-
a: Array | int | float | complex | bool,
108-
b: Array | int | float | complex | bool,
112+
a: Array | complex,
113+
b: Array | complex,
109114
xp: ModuleType,
110115
) -> tuple[Array, Array]:
111116
"""
@@ -150,9 +155,7 @@ def asarrays(
150155
if is_array_api_obj(a):
151156
# a is an Array API object
152157
# b is a int | float | complex | bool
153-
154-
# pyright doesn't like it if you reuse the same variable name
155-
xa = cast(Array, a)
158+
xa = a
156159

157160
# https://data-apis.org/array-api/draft/API_specification/type_promotion.html#mixing-arrays-with-python-scalars
158161
same_dtype = {
@@ -162,8 +165,8 @@ def asarrays(
162165
complex: "complex floating",
163166
}
164167
kind = same_dtype[type(b)] # type: ignore[index]
165-
if xp.isdtype(xa.dtype, kind):
166-
xb = xp.asarray(b, dtype=xa.dtype)
168+
if xp.isdtype(a.dtype, kind):
169+
xb = xp.asarray(b, dtype=a.dtype)
167170
else:
168171
# Undefined behaviour. Let the function deal with it, if it can.
169172
xb = xp.asarray(b)

0 commit comments

Comments
 (0)