Skip to content

TYP: Type annotations, part 4 #313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions array_api_compat/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def wrapped_f(*args: object, **kwargs: object) -> object:
specification for more details.

"""
wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
return wrapped_f # pyright: ignore[reportReturnType]
wrapped_f.__signature__ = new_sig # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
return wrapped_f # type: ignore[return-value] # pyright: ignore[reportReturnType]

return inner

Expand Down
13 changes: 7 additions & 6 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, NamedTuple, cast

from ._helpers import _check_device, array_namespace
from ._helpers import device as _get_device
from ._helpers import is_cupy_namespace as _is_cupy_namespace
from ._helpers import is_cupy_namespace
from ._typing import Array, Device, DType, Namespace

if TYPE_CHECKING:
Expand Down Expand Up @@ -381,8 +382,8 @@ def clip(
# TODO: np.clip has other ufunc kwargs
out: Array | None = None,
) -> Array:
def _isscalar(a: object) -> TypeIs[int | float | None]:
return isinstance(a, (int, float, type(None)))
def _isscalar(a: object) -> TypeIs[float | None]:
return isinstance(a, int | float) or a is None

min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape
Expand Down Expand Up @@ -450,7 +451,7 @@ def reshape(
shape: tuple[int, ...],
xp: Namespace,
*,
copy: Optional[bool] = None,
copy: bool | None = None,
**kwargs: object,
) -> Array:
if copy is True:
Expand Down Expand Up @@ -657,7 +658,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
out = xp.sign(x, **kwargs)
# CuPy sign() does not propagate nans. See
# https://github.com/data-apis/array-api-compat/issues/136
if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
out[xp.isnan(x)] = xp.nan
return out[()]

Expand Down
34 changes: 14 additions & 20 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,40 +23,36 @@
SupportsIndex,
TypeAlias,
TypeGuard,
TypeVar,
cast,
overload,
)

from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace

if TYPE_CHECKING:

import cupy as cp
import dask.array as da
import jax
import ndonnx as ndx
import numpy as np
import numpy.typing as npt
import sparse # pyright: ignore[reportMissingTypeStubs]
import sparse
import torch

# TODO: import from typing (requires Python >=3.13)
from typing_extensions import TypeIs, TypeVar

_SizeT = TypeVar("_SizeT", bound = int | None)
from typing_extensions import TypeIs

_ZeroGradientArray: TypeAlias = npt.NDArray[np.void]
_CupyArray: TypeAlias = Any # cupy has no py.typed

_ArrayApiObj: TypeAlias = (
npt.NDArray[Any]
| cp.ndarray
| da.Array
| jax.Array
| ndx.Array
| sparse.SparseArray
| torch.Tensor
| SupportsArrayNamespace[Any]
| _CupyArray
)

_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"})
Expand Down Expand Up @@ -96,7 +92,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
return dtype == jax.float0


def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]:
"""
Return True if `x` is a NumPy array.

Expand Down Expand Up @@ -267,7 +263,7 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
return _issubclass_fast(cls, "sparse", "SparseArray")


def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType]
def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]:
"""
Return True if `x` is an array API compatible array object.

Expand Down Expand Up @@ -748,7 +744,7 @@ def device(x: _ArrayApiObj, /) -> Device:
return "cpu"
elif is_dask_array(x):
# Peek at the metadata of the Dask array to determine type
if is_numpy_array(x._meta): # pyright: ignore
if is_numpy_array(x._meta):
# Must be on CPU since backed by numpy
return "cpu"
return _DASK_DEVICE
Expand Down Expand Up @@ -777,7 +773,7 @@ def device(x: _ArrayApiObj, /) -> Device:
return "cpu"
# Return the device of the constituent array
return device(inner) # pyright: ignore
return x.device # pyright: ignore
return x.device # type: ignore # pyright: ignore


# Prevent shadowing, used below
Expand All @@ -786,11 +782,11 @@ def device(x: _ArrayApiObj, /) -> Device:

# Based on cupy.array_api.Array.to_device
def _cupy_to_device(
x: _CupyArray,
x: cp.ndarray,
device: Device,
/,
stream: int | Any | None = None,
) -> _CupyArray:
) -> cp.ndarray:
import cupy as cp

if device == "cpu":
Expand Down Expand Up @@ -819,7 +815,7 @@ def _torch_to_device(
x: torch.Tensor,
device: torch.device | str | int,
/,
stream: None = None,
stream: int | Any | None = None,
) -> torch.Tensor:
if stream is not None:
raise NotImplementedError
Expand Down Expand Up @@ -885,7 +881,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
# cupy does not yet have to_device
return _cupy_to_device(x, device, stream=stream)
elif is_torch_array(x):
return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType]
return _torch_to_device(x, device, stream=stream)
elif is_dask_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
Expand All @@ -912,8 +908,6 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
@overload
def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
@overload
def size(x: HasShape[Collection[None]]) -> None: ...
@overload
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
"""
Expand Down Expand Up @@ -948,7 +942,7 @@ def _is_writeable_cls(cls: type) -> bool | None:
return None


def is_writeable_array(x: object) -> bool:
def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]:
"""
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
Return False if `x` is not an array API compatible object.
Expand Down Expand Up @@ -986,7 +980,7 @@ def _is_lazy_cls(cls: type) -> bool | None:
return None


def is_lazy_array(x: object) -> bool:
def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
"""Return True if x is potentially a future or it may be otherwise impossible or
expensive to eagerly read its contents, regardless of their size, e.g. by
calling ``bool(x)`` or ``float(x)``.
Expand Down
6 changes: 3 additions & 3 deletions array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
if np.__version__[0] == "2":
from numpy.lib.array_utils import normalize_axis_tuple
else:
from numpy.core.numeric import normalize_axis_tuple
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]

from .._internal import get_xp
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
Expand Down Expand Up @@ -187,14 +187,14 @@ def vector_norm(
# We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks
# above to avoid matrix norm logic.
shape = list(x.shape)
_axis = cast(
axes = cast(
"tuple[int, ...]",
normalize_axis_tuple( # pyright: ignore[reportCallIssue]
range(x.ndim) if axis is None else axis,
x.ndim,
),
)
for i in _axis:
for i in axes:
shape[i] = 1
res = xp.reshape(res, tuple(shape))

Expand Down
15 changes: 6 additions & 9 deletions array_api_compat/common/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,29 @@
# - docs: https://github.com/jorenham/optype/blob/master/README.md#just
# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
@final
class JustInt(Protocol):
@property
class JustInt(Protocol): # type: ignore[misc]
@property # type: ignore[override]
def __class__(self, /) -> type[int]: ...
@__class__.setter
def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]


@final
class JustFloat(Protocol):
@property
class JustFloat(Protocol): # type: ignore[misc]
@property # type: ignore[override]
def __class__(self, /) -> type[float]: ...
@__class__.setter
def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]


@final
class JustComplex(Protocol):
@property
class JustComplex(Protocol): # type: ignore[misc]
@property # type: ignore[override]
def __class__(self, /) -> type[complex]: ...
@__class__.setter
def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride]


#


class NestedSequence(Protocol[_T_co]):
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
def __len__(self, /) -> int: ...
Expand Down
31 changes: 14 additions & 17 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Optional
from builtins import bool as py_bool

import cupy as cp

Expand Down Expand Up @@ -67,18 +67,13 @@

# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
obj: (
Array
| bool | int | float | complex
| NestedSequence[bool | int | float | complex]
| SupportsBufferProtocol
),
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
/,
*,
dtype: Optional[DType] = None,
device: Optional[Device] = None,
copy: Optional[bool] = None,
**kwargs,
dtype: DType | None = None,
device: Device | None = None,
copy: py_bool | None = None,
**kwargs: object,
) -> Array:
"""
Array API compatibility wrapper for asarray().
Expand All @@ -101,8 +96,8 @@ def astype(
dtype: DType,
/,
*,
copy: bool = True,
device: Optional[Device] = None,
copy: py_bool = True,
device: Device | None = None,
) -> Array:
if device is None:
return x.astype(dtype=dtype, copy=copy)
Expand All @@ -113,8 +108,8 @@ def astype(
# cupy.count_nonzero does not have keepdims
def count_nonzero(
x: Array,
axis=None,
keepdims=False
axis: int | tuple[int, ...] | None = None,
keepdims: py_bool = False,
) -> Array:
result = cp.count_nonzero(x, axis)
if keepdims:
Expand All @@ -125,7 +120,7 @@ def count_nonzero(


# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
return cp.take_along_axis(x, indices, axis=axis)


Expand Down Expand Up @@ -153,4 +148,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1):
'bool', 'concat', 'count_nonzero', 'pow', 'sign',
'take_along_axis']

_all_ignore = ['cp', 'get_xp']

def __dir__() -> list[str]:
return __all__
9 changes: 5 additions & 4 deletions array_api_compat/cupy/fft.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from cupy.fft import * # noqa: F403
from cupy.fft import * # noqa: F403

# cupy.fft doesn't have __all__. If it is added, replace this with
#
# from cupy.fft import __all__ as fft_all
_n = {}
exec('from cupy.fft import *', _n)
del _n['__builtins__']
_n: dict[str, object] = {}
exec("from cupy.fft import *", _n)
del _n["__builtins__"]
fft_all = list(_n)
del _n

Expand Down
2 changes: 1 addition & 1 deletion array_api_compat/cupy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# cupy.linalg doesn't have __all__. If it is added, replace this with
#
# from cupy.linalg import __all__ as linalg_all
_n = {}
_n: dict[str, object] = {}
exec('from cupy.linalg import *', _n)
del _n['__builtins__']
linalg_all = list(_n)
Expand Down
2 changes: 1 addition & 1 deletion array_api_compat/dask/array/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dask.array import * # noqa: F403

# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
from ._aliases import * # type: ignore[assignment] # noqa: F403

__array_api_version__: Final = "2024.12"

Expand Down
2 changes: 1 addition & 1 deletion array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def arange(

# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol,
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
/,
*,
dtype: DType | None = None,
Expand Down
Loading
Loading