diff --git a/docs/index.md b/docs/index.md index f7c51574..ae15c7f4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -185,7 +185,7 @@ increase performance. In particular, the following kinds of function are also in-scope: - Functions which implement - [array API standard extension](https://data-apis.org/array-api/2023.12/extensions/index.html) + [array API standard extension](https://data-apis.org/array-api/latest/extensions/index.html) functions in terms of functions from the base standard. - Functions which add functionality (e.g. extra parameters) to functions from the standard. diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index f3295c45..0c455ae1 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -31,8 +31,8 @@ def _delegate(xp: ModuleType, *backends: Backend) -> bool: def isclose( - a: Array, - b: Array, + a: Array | complex, + b: Array | complex, *, rtol: float = 1e-05, atol: float = 1e-08, diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 5621017a..cf06dd55 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -7,7 +7,7 @@ import warnings from collections.abc import Sequence from types import ModuleType -from typing import cast +from typing import TYPE_CHECKING, cast from ._at import at from ._utils import _compat, _helpers @@ -375,8 +375,8 @@ def expand_dims( def isclose( - a: Array, - b: Array, + a: Array | complex, + b: Array | complex, *, rtol: float = 1e-05, atol: float = 1e-08, @@ -385,6 +385,10 @@ def isclose( ) -> Array: # numpydoc ignore=PR01,RT01 """See docstring in array_api_extra._delegation.""" a, b = asarrays(a, b, xp=xp) + # FIXME https://github.com/microsoft/pyright/issues/10085 + if TYPE_CHECKING: # pragma: nocover + assert _compat.is_array_api_obj(a) + assert _compat.is_array_api_obj(b) a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating")) b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating")) @@ -419,7 +423,13 @@ def isclose( return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol) -def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array: +def kron( + a: Array | complex, + b: Array | complex, + /, + *, + xp: ModuleType | None = None, +) -> Array: """ Kronecker product of two arrays. @@ -495,9 +505,16 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array: if xp is None: xp = array_namespace(a, b) a, b = asarrays(a, b, xp=xp) + # FIXME https://github.com/microsoft/pyright/issues/10085 + if TYPE_CHECKING: # pragma: nocover + assert _compat.is_array_api_obj(a) + assert _compat.is_array_api_obj(b) singletons = (1,) * (b.ndim - a.ndim) a = xp.broadcast_to(a, singletons + a.shape) + # FIXME https://github.com/microsoft/pyright/issues/10085 + if TYPE_CHECKING: # pragma: nocover + assert _compat.is_array_api_obj(a) nd_b, nd_a = b.ndim, a.ndim nd_max = max(nd_b, nd_a) @@ -614,8 +631,8 @@ def pad( def setdiff1d( - x1: Array, - x2: Array, + x1: Array | complex, + x2: Array | complex, /, *, assume_unique: bool = False, @@ -628,7 +645,7 @@ def setdiff1d( Parameters ---------- - x1 : array + x1 : array | int | float | complex | bool Input array. x2 : array Input comparison array. @@ -665,6 +682,11 @@ def setdiff1d( else: x1 = xp.unique_values(x1) x2 = xp.unique_values(x2) + + # FIXME https://github.com/microsoft/pyright/issues/10085 + if TYPE_CHECKING: # pragma: nocover + assert _compat.is_array_api_obj(x1) + return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)] diff --git a/src/array_api_extra/_lib/_utils/_compat.pyi b/src/array_api_extra/_lib/_utils/_compat.pyi index 1f585a38..66134fae 100644 --- a/src/array_api_extra/_lib/_utils/_compat.pyi +++ b/src/array_api_extra/_lib/_utils/_compat.pyi @@ -5,27 +5,30 @@ from __future__ import annotations from types import ModuleType +# TODO import from typing (requires Python >=3.13) +from typing_extensions import TypeIs + from ._typing import Array, Device # pylint: disable=missing-class-docstring,unused-argument -class ArrayModule(ModuleType): +class Namespace(ModuleType): def device(self, x: Array, /) -> Device: ... def array_namespace( - *xs: Array, + *xs: Array | complex | None, api_version: str | None = None, use_compat: bool | None = None, -) -> ArrayModule: ... +) -> Namespace: ... def device(x: Array, /) -> Device: ... -def is_array_api_obj(x: object, /) -> bool: ... -def is_array_api_strict_namespace(xp: ModuleType, /) -> bool: ... -def is_cupy_namespace(xp: ModuleType, /) -> bool: ... -def is_dask_namespace(xp: ModuleType, /) -> bool: ... -def is_jax_namespace(xp: ModuleType, /) -> bool: ... -def is_numpy_namespace(xp: ModuleType, /) -> bool: ... -def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ... -def is_torch_namespace(xp: ModuleType, /) -> bool: ... +def is_array_api_obj(x: object, /) -> TypeIs[Array]: ... +def is_array_api_strict_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... +def is_cupy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... +def is_dask_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... +def is_jax_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... +def is_numpy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... +def is_pydata_sparse_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... +def is_torch_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ... def is_cupy_array(x: object, /) -> bool: ... def is_dask_array(x: object, /) -> bool: ... def is_jax_array(x: object, /) -> bool: ... diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index e8419f24..594b6e12 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -5,12 +5,17 @@ from collections.abc import Generator from types import ModuleType -from typing import cast +from typing import TYPE_CHECKING from . import _compat from ._compat import array_namespace, is_array_api_obj, is_numpy_array from ._typing import Array +if TYPE_CHECKING: # pragma: no cover + # TODO import from typing (requires Python >=3.13) + from typing_extensions import TypeIs + + __all__ = ["asarrays", "in1d", "is_python_scalar", "mean"] @@ -96,16 +101,17 @@ def mean( return xp.mean(x, axis=axis, keepdims=keepdims) -def is_python_scalar(x: object) -> bool: # numpydoc ignore=PR01,RT01 +def is_python_scalar(x: object) -> TypeIs[complex]: # numpydoc ignore=PR01,RT01 """Return True if `x` is a Python scalar, False otherwise.""" # isinstance(x, float) returns True for np.float64 # isinstance(x, complex) returns True for np.complex128 - return isinstance(x, int | float | complex | bool) and not is_numpy_array(x) + # bool is a subclass of int + return isinstance(x, int | float | complex) and not is_numpy_array(x) def asarrays( - a: Array | int | float | complex | bool, - b: Array | int | float | complex | bool, + a: Array | complex, + b: Array | complex, xp: ModuleType, ) -> tuple[Array, Array]: """ @@ -150,9 +156,7 @@ def asarrays( if is_array_api_obj(a): # a is an Array API object # b is a int | float | complex | bool - - # pyright doesn't like it if you reuse the same variable name - xa = cast(Array, a) + xa = a # https://data-apis.org/array-api/draft/API_specification/type_promotion.html#mixing-arrays-with-python-scalars same_dtype = { @@ -162,8 +166,8 @@ def asarrays( complex: "complex floating", } kind = same_dtype[type(b)] # type: ignore[index] - if xp.isdtype(xa.dtype, kind): - xb = xp.asarray(b, dtype=xa.dtype) + if xp.isdtype(a.dtype, kind): + xb = xp.asarray(b, dtype=a.dtype) else: # Undefined behaviour. Let the function deal with it, if it can. xb = xp.asarray(b)