Skip to content

MAINT: Array API 2024.12 typing nits #156

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ increase performance.
In particular, the following kinds of function are also in-scope:

- Functions which implement
[array API standard extension](https://data-apis.org/array-api/2023.12/extensions/index.html)
[array API standard extension](https://data-apis.org/array-api/latest/extensions/index.html)
functions in terms of functions from the base standard.
- Functions which add functionality (e.g. extra parameters) to functions from
the standard.
Expand Down
4 changes: 2 additions & 2 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def _delegate(xp: ModuleType, *backends: Backend) -> bool:


def isclose(
a: Array,
b: Array,
a: Array | complex,
b: Array | complex,
*,
rtol: float = 1e-05,
atol: float = 1e-08,
Expand Down
36 changes: 29 additions & 7 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
from collections.abc import Sequence
from types import ModuleType
from typing import cast
from typing import TYPE_CHECKING, cast

from ._at import at
from ._utils import _compat, _helpers
Expand Down Expand Up @@ -375,8 +375,8 @@ def expand_dims(


def isclose(
a: Array,
b: Array,
a: Array | complex,
b: Array | complex,
*,
rtol: float = 1e-05,
atol: float = 1e-08,
Expand All @@ -385,6 +385,10 @@ def isclose(
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
a, b = asarrays(a, b, xp=xp)
# FIXME https://github.com/microsoft/pyright/issues/10085
if TYPE_CHECKING: # pragma: nocover
assert _compat.is_array_api_obj(a)
assert _compat.is_array_api_obj(b)

a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
Expand Down Expand Up @@ -419,7 +423,13 @@ def isclose(
return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol)


def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
def kron(
a: Array | complex,
b: Array | complex,
/,
*,
xp: ModuleType | None = None,
) -> Array:
"""
Kronecker product of two arrays.

Expand Down Expand Up @@ -495,9 +505,16 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
if xp is None:
xp = array_namespace(a, b)
a, b = asarrays(a, b, xp=xp)
# FIXME https://github.com/microsoft/pyright/issues/10085
if TYPE_CHECKING: # pragma: nocover
assert _compat.is_array_api_obj(a)
assert _compat.is_array_api_obj(b)

singletons = (1,) * (b.ndim - a.ndim)
a = xp.broadcast_to(a, singletons + a.shape)
# FIXME https://github.com/microsoft/pyright/issues/10085
if TYPE_CHECKING: # pragma: nocover
assert _compat.is_array_api_obj(a)

nd_b, nd_a = b.ndim, a.ndim
nd_max = max(nd_b, nd_a)
Expand Down Expand Up @@ -614,8 +631,8 @@ def pad(


def setdiff1d(
x1: Array,
x2: Array,
x1: Array | complex,
x2: Array | complex,
/,
*,
assume_unique: bool = False,
Expand All @@ -628,7 +645,7 @@ def setdiff1d(

Parameters
----------
x1 : array
x1 : array | int | float | complex | bool
Input array.
x2 : array
Input comparison array.
Expand Down Expand Up @@ -665,6 +682,11 @@ def setdiff1d(
else:
x1 = xp.unique_values(x1)
x2 = xp.unique_values(x2)

# FIXME https://github.com/microsoft/pyright/issues/10085
if TYPE_CHECKING: # pragma: nocover
assert _compat.is_array_api_obj(x1)

return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]


Expand Down
25 changes: 14 additions & 11 deletions src/array_api_extra/_lib/_utils/_compat.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,30 @@ from __future__ import annotations

from types import ModuleType

# TODO import from typing (requires Python >=3.13)
from typing_extensions import TypeIs

from ._typing import Array, Device

# pylint: disable=missing-class-docstring,unused-argument

class ArrayModule(ModuleType):
class Namespace(ModuleType):
def device(self, x: Array, /) -> Device: ...

def array_namespace(
*xs: Array,
*xs: Array | complex | None,
api_version: str | None = None,
use_compat: bool | None = None,
) -> ArrayModule: ...
) -> Namespace: ...
def device(x: Array, /) -> Device: ...
def is_array_api_obj(x: object, /) -> bool: ...
def is_array_api_strict_namespace(xp: ModuleType, /) -> bool: ...
def is_cupy_namespace(xp: ModuleType, /) -> bool: ...
def is_dask_namespace(xp: ModuleType, /) -> bool: ...
def is_jax_namespace(xp: ModuleType, /) -> bool: ...
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
def is_array_api_obj(x: object, /) -> TypeIs[Array]: ...
def is_array_api_strict_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
def is_cupy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
def is_dask_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
def is_jax_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
def is_numpy_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
def is_pydata_sparse_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
def is_torch_namespace(xp: ModuleType, /) -> TypeIs[Namespace]: ...
def is_cupy_array(x: object, /) -> bool: ...
def is_dask_array(x: object, /) -> bool: ...
def is_jax_array(x: object, /) -> bool: ...
Expand Down
24 changes: 14 additions & 10 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@

from collections.abc import Generator
from types import ModuleType
from typing import cast
from typing import TYPE_CHECKING

from . import _compat
from ._compat import array_namespace, is_array_api_obj, is_numpy_array
from ._typing import Array

if TYPE_CHECKING: # pragma: no cover
# TODO import from typing (requires Python >=3.13)
from typing_extensions import TypeIs


__all__ = ["asarrays", "in1d", "is_python_scalar", "mean"]


Expand Down Expand Up @@ -96,16 +101,17 @@ def mean(
return xp.mean(x, axis=axis, keepdims=keepdims)


def is_python_scalar(x: object) -> bool: # numpydoc ignore=PR01,RT01
def is_python_scalar(x: object) -> TypeIs[complex]: # numpydoc ignore=PR01,RT01
"""Return True if `x` is a Python scalar, False otherwise."""
# isinstance(x, float) returns True for np.float64
# isinstance(x, complex) returns True for np.complex128
return isinstance(x, int | float | complex | bool) and not is_numpy_array(x)
# bool is a subclass of int
return isinstance(x, int | float | complex) and not is_numpy_array(x)


def asarrays(
a: Array | int | float | complex | bool,
b: Array | int | float | complex | bool,
a: Array | complex,
b: Array | complex,
xp: ModuleType,
) -> tuple[Array, Array]:
"""
Expand Down Expand Up @@ -150,9 +156,7 @@ def asarrays(
if is_array_api_obj(a):
# a is an Array API object
# b is a int | float | complex | bool

# pyright doesn't like it if you reuse the same variable name
xa = cast(Array, a)
xa = a

# https://data-apis.org/array-api/draft/API_specification/type_promotion.html#mixing-arrays-with-python-scalars
same_dtype = {
Expand All @@ -162,8 +166,8 @@ def asarrays(
complex: "complex floating",
}
kind = same_dtype[type(b)] # type: ignore[index]
if xp.isdtype(xa.dtype, kind):
xb = xp.asarray(b, dtype=xa.dtype)
if xp.isdtype(a.dtype, kind):
xb = xp.asarray(b, dtype=a.dtype)
else:
# Undefined behaviour. Let the function deal with it, if it can.
xb = xp.asarray(b)
Expand Down