From b5582a70caa2df2bb77fefb8c05aec1caccea28c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 27 Feb 2025 14:09:18 +0000 Subject: [PATCH 01/13] ENH: Type annotations overhaul --- array_api_compat/common/_aliases.py | 234 +++++++++++++----------- array_api_compat/common/_fft.py | 87 ++++----- array_api_compat/common/_helpers.py | 32 ++-- array_api_compat/common/_linalg.py | 84 +++++---- array_api_compat/common/_typing.py | 16 +- array_api_compat/cupy/_aliases.py | 25 ++- array_api_compat/cupy/_typing.py | 67 +++---- array_api_compat/dask/array/_aliases.py | 46 ++--- array_api_compat/dask/array/fft.py | 13 +- array_api_compat/dask/array/linalg.py | 25 +-- array_api_compat/numpy/_aliases.py | 26 ++- array_api_compat/numpy/_typing.py | 67 +++---- array_api_compat/torch/_aliases.py | 157 ++++++++-------- array_api_compat/torch/_typing.py | 4 + array_api_compat/torch/fft.py | 35 ++-- array_api_compat/torch/linalg.py | 28 ++- tests/test_all.py | 17 +- 17 files changed, 487 insertions(+), 476 deletions(-) create mode 100644 array_api_compat/torch/_typing.py diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 98b8e425..9f771f41 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -4,15 +4,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Sequence, Tuple, Union - from ._typing import ndarray, Device, Dtype - -from typing import NamedTuple import inspect +from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace +from ._typing import Array, Device, DType, Namespace # These functions are modified from the NumPy versions. 
@@ -24,29 +20,34 @@ def arange( stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - xp, - dtype: Optional[Dtype] = None, + xp: Namespace, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: _check_device(xp, device) return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) def empty( shape: Union[int, Tuple[int, ...]], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: _check_device(xp, device) return xp.empty(shape, dtype=dtype, **kwargs) def empty_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + x: Array, + /, + xp: Namespace, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, + **kwargs, +) -> Array: _check_device(xp, device) return xp.empty_like(x, dtype=dtype, **kwargs) @@ -55,37 +56,37 @@ def eye( n_cols: Optional[int] = None, /, *, - xp, + xp: Namespace, k: int = 0, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) def full( shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.full(shape, fill_value, dtype=dtype, **kwargs) def full_like( - x: ndarray, + x: Array, /, fill_value: Union[int, float], *, - xp, - dtype: Optional[Dtype] = None, + xp: Namespace, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.full_like(x, fill_value, dtype=dtype, **kwargs) @@ -95,48 +96,58 @@ def linspace( /, num: int, *, - xp, - dtype: Optional[Dtype] = None, + xp: Namespace, + dtype: Optional[DType] = None, device: Optional[Device] = None, endpoint: bool = True, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) def ones( shape: Union[int, Tuple[int, ...]], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.ones(shape, dtype=dtype, **kwargs) def ones_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + x: Array, + /, + xp: Namespace, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.ones_like(x, dtype=dtype, **kwargs) def zeros( shape: Union[int, Tuple[int, ...]], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.zeros(shape, dtype=dtype, **kwargs) def zeros_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + x: Array, + /, + xp: Namespace, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.zeros_like(x, dtype=dtype, **kwargs) @@ -150,23 +161,23 @@ def zeros_like( # 
Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. class UniqueAllResult(NamedTuple): - values: ndarray - indices: ndarray - inverse_indices: ndarray - counts: ndarray + values: Array + indices: Array + inverse_indices: Array + counts: Array class UniqueCountsResult(NamedTuple): - values: ndarray - counts: ndarray + values: Array + counts: Array class UniqueInverseResult(NamedTuple): - values: ndarray - inverse_indices: ndarray + values: Array + inverse_indices: Array -def _unique_kwargs(xp): +def _unique_kwargs(xp: Namespace) -> dict[str, Any]: # Older versions of NumPy and CuPy do not have equal_nan. Rather than # trying to parse version numbers, just check if equal_nan is in the # signature. @@ -175,7 +186,7 @@ def _unique_kwargs(xp): return {'equal_nan': False} return {} -def unique_all(x: ndarray, /, xp) -> UniqueAllResult: +def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult: kwargs = _unique_kwargs(xp) values, indices, inverse_indices, counts = xp.unique( x, @@ -195,7 +206,7 @@ def unique_all(x: ndarray, /, xp) -> UniqueAllResult: ) -def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: +def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult: kwargs = _unique_kwargs(xp) res = xp.unique( x, @@ -208,7 +219,7 @@ def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: return UniqueCountsResult(*res) -def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: +def unique_inverse(x: Array, /, xp: Namespace) -> UniqueInverseResult: kwargs = _unique_kwargs(xp) values, inverse_indices = xp.unique( x, @@ -223,7 +234,7 @@ def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: return UniqueInverseResult(values, inverse_indices) -def unique_values(x: ndarray, /, xp) -> ndarray: +def unique_values(x: Array, /, xp: Namespace) -> Array: kwargs = _unique_kwargs(xp) return xp.unique( x, @@ -236,42 +247,42 @@ def unique_values(x: ndarray, /, xp) -> ndarray: # These functions have different keyword argument names def std( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, **kwargs, -) -> ndarray: +) -> Array: return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) def var( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, **kwargs, -) -> ndarray: +) -> Array: return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) # cumulative_sum is renamed from cumsum, and adds the include_initial keyword # argument def cumulative_sum( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, include_initial: bool = False, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: wrapped_xp = array_namespace(x) # TODO: The standard is not clear about what should happen when x.ndim == 0. 
@@ -294,15 +305,15 @@ def cumulative_sum( def cumulative_prod( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, include_initial: bool = False, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: wrapped_xp = array_namespace(x) if axis is None: @@ -325,17 +336,18 @@ def cumulative_prod( # The min and max argument names in clip are different and not optional in numpy, and type # promotion behavior is different. def clip( - x: ndarray, + x: Array, /, - min: Optional[Union[int, float, ndarray]] = None, - max: Optional[Union[int, float, ndarray]] = None, + min: Optional[Union[int, float, Array]] = None, + max: Optional[Union[int, float, Array]] = None, *, - xp, + xp: Namespace, # TODO: np.clip has other ufunc kwargs - out: Optional[ndarray] = None, -) -> ndarray: + out: Optional[Array] = None, +) -> Array: def _isscalar(a): return isinstance(a, (int, float, type(None))) + min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -389,15 +401,18 @@ def _isscalar(a): return out[()] # Unlike transpose(), the axes argument to permute_dims() is required. -def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray: +def permute_dims(x: Array, /, axes: Tuple[int, ...], xp: Namespace) -> Array: return xp.transpose(x, axes) # np.reshape calls the keyword argument 'newshape' instead of 'shape' -def reshape(x: ndarray, - /, - shape: Tuple[int, ...], - xp, copy: Optional[bool] = None, - **kwargs) -> ndarray: +def reshape( + x: Array, + /, + shape: Tuple[int, ...], + xp: Namespace, + copy: Optional[bool] = None, + **kwargs, +) -> Array: if copy is True: x = x.copy() elif copy is False: @@ -409,9 +424,15 @@ def reshape(x: ndarray, # The descending keyword is new in sort and argsort, and 'kind' replaced with # 'stable' def argsort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, + x: Array, + /, + xp: Namespace, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, **kwargs, -) -> ndarray: +) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. @@ -434,9 +455,15 @@ def argsort( return res def sort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, + x: Array, + /, + xp: Namespace, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, **kwargs, -) -> ndarray: +) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. 
@@ -448,50 +475,50 @@ def sort( return res # nonzero should error for zero-dimensional arrays -def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]: +def nonzero(x: Array, /, xp: Namespace, **kwargs) -> Tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return xp.nonzero(x, **kwargs) # ceil, floor, and trunc return integers for integer inputs -def ceil(x: ndarray, /, xp, **kwargs) -> ndarray: +def ceil(x: Array, /, xp: Namespace, **kwargs) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.ceil(x, **kwargs) -def floor(x: ndarray, /, xp, **kwargs) -> ndarray: +def floor(x: Array, /, xp: Namespace, **kwargs) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.floor(x, **kwargs) -def trunc(x: ndarray, /, xp, **kwargs) -> ndarray: +def trunc(x: Array, /, xp: Namespace, **kwargs) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.trunc(x, **kwargs) # linear algebra functions -def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: +def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: return xp.matmul(x1, x2, **kwargs) # Unlike transpose, matrix_transpose only transposes the last two axes. -def matrix_transpose(x: ndarray, /, xp) -> ndarray: +def matrix_transpose(x: Array, /, xp: Namespace) -> Array: if x.ndim < 2: raise ValueError("x must be at least 2-dimensional for matrix_transpose") return xp.swapaxes(x, -1, -2) -def tensordot(x1: ndarray, - x2: ndarray, +def tensordot(x1: Array, + x2: Array, /, - xp, + xp: Namespace, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs, -) -> ndarray: +) -> Array: return xp.tensordot(x1, x2, axes=axes, **kwargs) -def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: +def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array: if x1.shape[axis] != x2.shape[axis]: raise ValueError("x1 and x2 must have the same size along the given axis") @@ -510,8 +537,11 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: # isdtype is a new function in the 2022.12 array API specification. def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp, - *, _tuple=True, # Disallow nested tuples + dtype: DType, + kind: Union[DType, str, Tuple[Union[DType, str], ...]], + xp: Namespace, + *, + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. @@ -550,14 +580,14 @@ def isdtype( return dtype == kind # unstack is a new function in the 2023.12 array API standard -def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]: +def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> Tuple[Array, ...]: if x.ndim == 0: raise ValueError("Input array must be at least 1-d.") return tuple(xp.moveaxis(x, axis, 0)) # numpy 1.26 does not use the standard definition for sign on complex numbers -def sign(x: ndarray, /, xp, **kwargs) -> ndarray: +def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: if isdtype(x.dtype, 'complex floating', xp=xp): out = (x/xp.abs(x, **kwargs))[...] 
# sign(0) = 0 but the above formula would give nan diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index e5caebef..bd2a4e1a 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -1,149 +1,148 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union, Optional, Literal +from collections.abc import Sequence +from typing import Union, Optional, Literal -if TYPE_CHECKING: - from ._typing import Device, ndarray, DType - from collections.abc import Sequence +from ._typing import Device, Array, DType, Namespace # Note: NumPy fft functions improperly upcast float32 and complex64 to # complex128, which is why we require wrapping them all here. def fft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.fft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def fftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def rfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.float32: return res.astype(xp.complex64) return res def irfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def rfftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.float32: return res.astype(xp.complex64) return res def irfftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def hfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.hfft(x, n=n, 
axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.float32) return res def ihfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) @@ -152,12 +151,12 @@ def ihfft( def fftfreq( n: int, /, - xp, + xp: Namespace, *, d: float = 1.0, dtype: Optional[DType] = None, - device: Optional[Device] = None -) -> ndarray: + device: Optional[Device] = None, +) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = xp.fft.fftfreq(n, d=d) @@ -168,12 +167,12 @@ def fftfreq( def rfftfreq( n: int, /, - xp, + xp: Namespace, *, d: float = 1.0, dtype: Optional[DType] = None, - device: Optional[Device] = None -) -> ndarray: + device: Optional[Device] = None, +) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = xp.fft.rfftfreq(n, d=d) @@ -181,10 +180,14 @@ def rfftfreq( return res.astype(dtype) return res -def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: +def fftshift( + x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None +) -> Array: return xp.fft.fftshift(x, axes=axes) -def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: +def ifftshift( + x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None +) -> Array: return xp.fft.ifftshift(x, axes=axes) __all__ = [ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 791edb81..6d95069d 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -7,16 +7,14 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Optional, Union, Any - from ._typing import Array, Device, Namespace - import sys import math import inspect import warnings +from typing import Optional, Union, Any + +from ._typing import Array, Device, Namespace + def _is_jax_zero_gradient_array(x: object) -> bool: """Return True if `x` is a zero-gradient array. @@ -268,7 +266,7 @@ def _compat_module_name() -> str: return __name__.removesuffix('.common._helpers') -def is_numpy_namespace(xp) -> bool: +def is_numpy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a NumPy namespace. @@ -289,7 +287,7 @@ def is_numpy_namespace(xp) -> bool: return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} -def is_cupy_namespace(xp) -> bool: +def is_cupy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a CuPy namespace. @@ -310,7 +308,7 @@ def is_cupy_namespace(xp) -> bool: return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} -def is_torch_namespace(xp) -> bool: +def is_torch_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a PyTorch namespace. @@ -331,7 +329,7 @@ def is_torch_namespace(xp) -> bool: return xp.__name__ in {'torch', _compat_module_name() + '.torch'} -def is_ndonnx_namespace(xp) -> bool: +def is_ndonnx_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is an NDONNX namespace. @@ -350,7 +348,7 @@ def is_ndonnx_namespace(xp) -> bool: return xp.__name__ == 'ndonnx' -def is_dask_namespace(xp) -> bool: +def is_dask_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a Dask namespace. 
@@ -371,7 +369,7 @@ def is_dask_namespace(xp) -> bool: return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} -def is_jax_namespace(xp) -> bool: +def is_jax_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a JAX namespace. @@ -393,7 +391,7 @@ def is_jax_namespace(xp) -> bool: return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'} -def is_pydata_sparse_namespace(xp) -> bool: +def is_pydata_sparse_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a pydata/sparse namespace. @@ -412,7 +410,7 @@ def is_pydata_sparse_namespace(xp) -> bool: return xp.__name__ == 'sparse' -def is_array_api_strict_namespace(xp) -> bool: +def is_array_api_strict_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is an array-api-strict namespace. @@ -439,7 +437,11 @@ def _check_api_version(api_version: str) -> None: raise ValueError("Only the 2024.12 version of the array API specification is currently supported") -def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace: +def array_namespace( + *xs: Union[Array, bool, int, float, complex, None], + api_version: Optional[str] = None, + use_compat: Optional[bool] = None, +) -> Namespace: """ Get the array API compatible namespace for the arrays `xs`. diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index bfa1f1b9..c77ee3b8 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -1,11 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, NamedTuple -if TYPE_CHECKING: - from typing import Literal, Optional, Tuple, Union - from ._typing import ndarray - import math +from typing import Literal, NamedTuple, Optional, Tuple, Union import numpy as np if np.__version__[0] == "2": @@ -15,50 +11,53 @@ from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype from .._internal import get_xp +from ._typing import Array, Namespace # These are in the main NumPy namespace but not in numpy.linalg -def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray: +def cross(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1, **kwargs) -> Array: return xp.cross(x1, x2, axis=axis, **kwargs) -def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: +def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: return xp.outer(x1, x2, **kwargs) class EighResult(NamedTuple): - eigenvalues: ndarray - eigenvectors: ndarray + eigenvalues: Array + eigenvectors: Array class QRResult(NamedTuple): - Q: ndarray - R: ndarray + Q: Array + R: Array class SlogdetResult(NamedTuple): - sign: ndarray - logabsdet: ndarray + sign: Array + logabsdet: Array class SVDResult(NamedTuple): - U: ndarray - S: ndarray - Vh: ndarray + U: Array + S: Array + Vh: Array # These functions are the same as their NumPy counterparts except they return # a namedtuple. 
-def eigh(x: ndarray, /, xp, **kwargs) -> EighResult: +def eigh(x: Array, /, xp: Namespace, **kwargs) -> EighResult: return EighResult(*xp.linalg.eigh(x, **kwargs)) -def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced', +def qr(x: Array, /, xp: Namespace, *, mode: Literal['reduced', 'complete'] = 'reduced', **kwargs) -> QRResult: return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) -def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult: +def slogdet(x: Array, /, xp: Namespace, **kwargs) -> SlogdetResult: return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) -def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd( + x: Array, /, xp: Namespace, *, full_matrices: bool = True, **kwargs +) -> SVDResult: return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) # These functions have additional keyword arguments # The upper keyword argument is new from NumPy -def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: +def cholesky(x: Array, /, xp: Namespace, *, upper: bool = False, **kwargs) -> Array: L = xp.linalg.cholesky(x, **kwargs) if upper: U = get_xp(xp)(matrix_transpose)(L) @@ -69,12 +68,12 @@ def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. # Note that it has a different semantic meaning from tol and rcond. -def matrix_rank(x: ndarray, +def matrix_rank(x: Array, /, - xp, + xp: Namespace, *, - rtol: Optional[Union[float, ndarray]] = None, - **kwargs) -> ndarray: + rtol: Optional[Union[float, Array]] = None, + **kwargs) -> Array: # this is different from xp.linalg.matrix_rank, which supports 1 # dimensional arrays. if x.ndim < 2: @@ -88,7 +87,9 @@ def matrix_rank(x: ndarray, tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis] return xp.count_nonzero(S > tol, axis=-1) -def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray: +def pinv( + x: Array, /, xp: Namespace, *, rtol: Optional[Union[float, Array]] = None, **kwargs +) -> Array: # this is different from xp.linalg.pinv, which does not multiply the # default tolerance by max(M, N). if rtol is None: @@ -97,15 +98,30 @@ def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **k # These functions are new in the array API spec -def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray: +def matrix_norm( + x: Array, + /, + xp: Namespace, + *, + keepdims: bool = False, + ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro', +) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) # svdvals is not in NumPy (but it is in SciPy). It is equivalent to # xp.linalg.svd(compute_uv=False). 
-def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]: +def svdvals(x: Array, /, xp: Namespace) -> Union[Array, Tuple[Array, ...]]: return xp.linalg.svd(x, compute_uv=False) -def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray: +def vector_norm( + x: Array, + /, + xp: Namespace, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + ord: Optional[Union[int, float]] = 2, +) -> Array: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make # it so the input is 1-D (for axis=None), or reshape so that norm is done @@ -143,11 +159,15 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] # xp.diagonal and xp.trace operate on the first two axes whereas these # operates on the last two -def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: +def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs) -> Array: return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) -def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray: - return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) +def trace( + x: Array, /, xp: Namespace, *, offset: int = 0, dtype=None, **kwargs +) -> Array: + return xp.asarray( + xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs) + ) __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', 'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet', diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index d8acdef7..4c3b356b 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -1,26 +1,24 @@ from __future__ import annotations +from types import ModuleType as Namespace +from typing import Any, TypeVar, Protocol __all__ = [ + "Array", + "DType", + "Device", + "Namespace", "NestedSequence", "SupportsBufferProtocol", ] -from types import ModuleType -from typing import ( - Any, - TypeVar, - Protocol, -) - _T_co = TypeVar("_T_co", covariant=True) class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... -SupportsBufferProtocol = Any +SupportsBufferProtocol = Any Array = Any Device = Any DType = Any -Namespace = ModuleType diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 30d9fe48..6f03b6ce 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,16 +1,14 @@ from __future__ import annotations +from typing import Optional, Union + import cupy as cp from ..common import _aliases, _helpers +from ..common._typing import NestedSequence, SupportsBufferProtocol from .._internal import get_xp - from ._info import __array_namespace_info__ - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Union - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType bool = cp.bool_ @@ -69,20 +67,21 @@ # asarray also adds the copy keyword, which is not present in numpy 1.0. 
def asarray( obj: Union[ - ndarray, + Array, bool, int, float, - NestedSequence[bool | int | float], + complex, + NestedSequence[bool | int | float | complex], SupportsBufferProtocol, ], /, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, copy: Optional[bool] = _copy_default, **kwargs, -) -> ndarray: +) -> Array: """ Array API compatibility wrapper for asarray(). @@ -112,13 +111,13 @@ def asarray( def astype( - x: ndarray, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, device: Optional[Device] = None, -) -> ndarray: +) -> Array: if device is None: return x.astype(dtype=dtype, copy=copy) out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device) diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index f3d9aab6..07c11f78 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -1,46 +1,33 @@ from __future__ import annotations -__all__ = [ - "ndarray", - "Device", - "Dtype", -] +__all__ = ["Array", "DType", "Device"] +_all_ignore = ["cp"] -import sys -from typing import ( - Union, - TYPE_CHECKING, -) - -from cupy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) +from typing import Union +import cupy as cp +from cupy import ndarray as Array from cupy.cuda.device import Device -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] -else: - Dtype = dtype +try: + DType = cp.dtype[ + Union[ + cp.intp, + cp.int8, + cp.int16, + cp.int32, + cp.int64, + cp.uint8, + cp.uint16, + cp.uint32, + cp.uint64, + cp.float32, + cp.float64, + cp.complex64, + cp.complex128, + cp.bool_, + ] + ] +except TypeError: + # NumPy 1.x on Python 3.9 and 3.10 + DType = cp.dtype diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 80d66281..1c922f3a 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,16 +1,10 @@ from __future__ import annotations -from typing import Callable - -from ...common import _aliases, array_namespace - -from ..._internal import get_xp - -from ._info import __array_namespace_info__ +from typing import Callable, Optional, Union import numpy as np from numpy import ( - # Dtypes + # dtypes iinfo, finfo, bool_ as bool, @@ -29,22 +23,19 @@ can_cast, result_type, ) - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Optional, Union - - from ...common._typing import ( - Device, - Dtype, - Array, - NestedSequence, - SupportsBufferProtocol, - ) - import dask.array as da +from ...common import _aliases, array_namespace +from ...common._typing import ( + Array, + Device, + DType, + NestedSequence, + SupportsBufferProtocol, +) +from ..._internal import get_xp +from ._info import __array_namespace_info__ + isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) @@ -52,7 +43,7 @@ # da.astype doesn't respect copy=True def astype( x: Array, - dtype: Dtype, + dtype: DType, /, *, copy: bool = True, @@ -84,7 +75,7 @@ def arange( stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, ) -> Array: @@ -149,12 +140,13 @@ def asarray( bool, int, float, - NestedSequence[bool | int | float], + complex, + 
NestedSequence[bool | int | float | complex], SupportsBufferProtocol, ], /, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, copy: Optional[Union[bool, np._CopyMode]] = None, **kwargs, @@ -360,4 +352,4 @@ def count_nonzero( 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'count_nonzero', 'result_type'] -_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"] +_all_ignore = ["array_namespace", "get_xp", "da", "np"] diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index aebd86f7..3f40dffe 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -4,9 +4,10 @@ # from dask.array.fft import __all__ as linalg_all _n = {} exec('from dask.array.fft import *', _n) -del _n['__builtins__'] +for k in ("__builtins__", "Sequence", "annotations", "warnings"): + _n.pop(k, None) fft_all = list(_n) -del _n +del _n, k from ...common import _fft from ..._internal import get_xp @@ -16,9 +17,5 @@ fftfreq = get_xp(da)(_fft.fftfreq) rfftfreq = get_xp(da)(_fft.rfftfreq) -__all__ = [elem for elem in fft_all if elem != "annotations"] + ["fftfreq", "rfftfreq"] - -del get_xp -del da -del fft_all -del _fft +__all__ = fft_all + ["fftfreq", "rfftfreq"] +_all_ignore = ["da", "fft_all", "get_xp", "warnings"] diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 49c26d8b..bd53f0df 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -1,33 +1,28 @@ from __future__ import annotations -from ...common import _linalg -from ..._internal import get_xp +from typing import Literal +import dask.array as da # Exports from dask.array.linalg import * # noqa: F403 from dask.array import outer - # These functions are in both the main and linalg namespaces from dask.array import matmul, tensordot -from ._aliases import matrix_transpose, vecdot -import dask.array as da - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from ...common._typing import Array - from typing import Literal +from ..._internal import get_xp +from ...common import _linalg +from ...common._typing import Array +from ._aliases import matrix_transpose, vecdot # dask.array.linalg doesn't have __all__. 
If it is added, replace this with # # from dask.array.linalg import __all__ as linalg_all _n = {} exec('from dask.array.linalg import *', _n) -del _n['__builtins__'] -if 'annotations' in _n: - del _n['annotations'] +for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'): + _n.pop(k, None) linalg_all = list(_n) -del _n +del _n, k EighResult = _linalg.EighResult QRResult = _linalg.QRResult @@ -70,4 +65,4 @@ def svdvals(x: Array) -> Array: "cholesky", "matrix_rank", "matrix_norm", "svdvals", "vector_norm", "diagonal"] -_all_ignore = ['get_xp', 'da', 'linalg_all'] +_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings'] diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index a47f7121..aedaa7a3 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -1,15 +1,12 @@ from __future__ import annotations -from ..common import _aliases +from typing import Optional, Union from .._internal import get_xp - +from ..common import _aliases +from ..common._typing import NestedSequence, SupportsBufferProtocol from ._info import __array_namespace_info__ - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Union - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType import numpy as np bool = np.bool_ @@ -77,20 +74,21 @@ def _supports_buffer_protocol(obj): # rather than trying to combine everything into one function in common/ def asarray( obj: Union[ - ndarray, + Array, bool, int, float, - NestedSequence[bool | int | float], + complex, + NestedSequence[bool | int | float | complex], SupportsBufferProtocol, ], /, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, copy: "Optional[Union[bool, np._CopyMode]]" = None, **kwargs, -) -> ndarray: +) -> Array: """ Array API compatibility wrapper for asarray(). 
@@ -117,13 +115,13 @@ def asarray( def astype( - x: ndarray, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, device: Optional[Device] = None, -) -> ndarray: +) -> Array: return x.astype(dtype=dtype, copy=copy) diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index c5ebb5ab..5557c5e9 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -1,46 +1,33 @@ from __future__ import annotations -__all__ = [ - "ndarray", - "Device", - "Dtype", -] +__all__ = ["Array", "DType", "Device"] +_all_ignore = ["np"] -import sys -from typing import ( - Literal, - Union, - TYPE_CHECKING, -) +from typing import Literal, Union -from numpy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) +import numpy as np +from numpy import ndarray as Array Device = Literal["cpu"] -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] -else: - Dtype = dtype +try: + DType = np.dtype[ + Union[ + np.intp, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.float32, + np.float64, + np.complex64, + np.complex128, + np.bool_, + ] + ] +except TypeError: + # NumPy 1.x on Python 3.9 and 3.10 + DType = np.dtype diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index b4786320..cb3c3bad 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -2,21 +2,13 @@ from functools import wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any +from typing import List, Optional, Sequence, Tuple, Union -from ..common import _aliases +import torch from .._internal import get_xp - +from ..common import _aliases from ._info import __array_namespace_info__ - -import torch - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import List, Optional, Sequence, Tuple, Union - from ..common._typing import Device - from torch import dtype as Dtype - - array = torch.Tensor +from ._typing import Array, Device, DType _int_dtypes = { torch.uint8, @@ -123,7 +115,7 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype: +def result_type(*arrays_and_dtypes: Union[Array, DType, bool, int, float, complex]) -> DType: if len(arrays_and_dtypes) == 0: raise TypeError("At least one array or dtype must be provided") if len(arrays_and_dtypes) == 1: @@ -151,7 +143,7 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, comple y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y return torch.result_type(x, y) -def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: +def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: if not isinstance(from_, torch.dtype): from_ = from_.dtype return torch.can_cast(from_, to) @@ -197,13 +189,13 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: # of 'axis'. 
# torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745 -def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: +def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) return torch.amax(x, axis, keepdims=keepdims) -def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: +def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -216,7 +208,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 -def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array: +def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> Array: return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values def _normalize_axes(axis, ndim): @@ -261,13 +253,13 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs): out = torch.unsqueeze(out, a) return out -def prod(x: array, +def prod(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim @@ -297,13 +289,13 @@ def prod(x: array, return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) -def sum(x: array, +def sum(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim @@ -328,12 +320,12 @@ def sum(x: array, return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) -def any(x: array, +def any(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim if axis == (): @@ -353,12 +345,12 @@ def any(x: array, # torch.any doesn't return bool for uint8 return torch.any(x, axis, keepdims=keepdims).to(torch.bool) -def all(x: array, +def all(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim if axis == (): @@ -378,12 +370,12 @@ def all(x: array, # torch.all doesn't return bool for uint8 return torch.all(x, axis, keepdims=keepdims).to(torch.bool) -def mean(x: array, +def mean(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -395,13 +387,13 @@ def mean(x: array, return res return torch.mean(x, axis, keepdims=keepdims, **kwargs) -def std(x: array, +def std(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. 
We don't try to # implement it here for now. @@ -426,13 +418,13 @@ def std(x: array, return res return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs) -def var(x: array, +def var(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -455,11 +447,11 @@ def var(x: array, # torch.concat doesn't support dim=None # https://github.com/pytorch/pytorch/issues/70925 -def concat(arrays: Union[Tuple[array, ...], List[array]], +def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0, - **kwargs) -> array: + **kwargs) -> Array: if axis is None: arrays = tuple(ar.flatten() for ar in arrays) axis = 0 @@ -468,7 +460,7 @@ def concat(arrays: Union[Tuple[array, ...], List[array]], # torch.squeeze only accepts int dim and doesn't require it # https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was # added at https://github.com/pytorch/pytorch/pull/89017. -def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: +def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: if isinstance(axis, int): axis = (axis,) for a in axis: @@ -482,27 +474,27 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: return x # torch.broadcast_to uses size instead of shape -def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array: +def broadcast_to(x: Array, /, shape: Tuple[int, ...], **kwargs) -> Array: return torch.broadcast_to(x, shape, **kwargs) # torch.permute uses dims instead of axes -def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: +def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: return torch.permute(x, axes) # The axis parameter doesn't work for flip() and roll() # https://github.com/pytorch/pytorch/issues/71210. 
Also torch.flip() doesn't # accept axis=None -def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: +def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: if axis is None: axis = tuple(range(x.ndim)) # torch.flip doesn't accept dim as an int but the method does # https://github.com/pytorch/pytorch/issues/18095 return x.flip(axis, **kwargs) -def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: +def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: return torch.roll(x, shift, axis, **kwargs) -def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: +def nonzero(x: Array, /, **kwargs) -> Tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return torch.nonzero(x, as_tuple=True, **kwargs) @@ -510,25 +502,25 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: # torch uses `dim` instead of `axis` def diff( - x: array, + x: Array, /, *, axis: int = -1, n: int = 1, - prepend: Optional[array] = None, - append: Optional[array] = None, -) -> array: + prepend: Optional[Array] = None, + append: Optional[Array] = None, +) -> Array: return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append) # torch uses `dim` instead of `axis`, does not have keepdims def count_nonzero( - x: array, + x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, -) -> array: +) -> Array: result = torch.count_nonzero(x, dim=axis) if keepdims: if axis is not None: @@ -538,17 +530,16 @@ def count_nonzero( return result - -def where(condition: array, x1: array, x2: array, /) -> array: +def where(condition: Array, x1: Array, x2: Array, /) -> Array: x1, x2 = _fix_promotion(x1, x2) return torch.where(condition, x1, x2) # torch.reshape doesn't have the copy keyword -def reshape(x: array, +def reshape(x: Array, /, shape: Tuple[int, ...], copy: Optional[bool] = None, - **kwargs) -> array: + **kwargs) -> Array: if copy is not None: raise NotImplementedError("torch.reshape doesn't yet support the copy keyword") return torch.reshape(x, shape, **kwargs) @@ -562,9 +553,9 @@ def arange(start: Union[int, float], stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: if stop is None: start, stop = 0, start if step > 0 and stop <= start or step < 0 and stop >= start: @@ -583,9 +574,9 @@ def eye(n_rows: int, /, *, k: int = 0, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: if n_cols is None: n_cols = n_rows z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs) @@ -599,10 +590,10 @@ def linspace(start: Union[int, float], /, num: int, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, endpoint: bool = True, - **kwargs) -> array: + **kwargs) -> Array: if not endpoint: return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1] return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs) @@ -612,9 +603,9 @@ def linspace(start: Union[int, float], def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[bool, int, float, complex], *, - dtype: 
Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: if isinstance(shape, int): shape = (shape,) @@ -623,52 +614,52 @@ def full(shape: Union[int, Tuple[int, ...]], # ones, zeros, and empty do not accept shape as a keyword argument def ones(shape: Union[int, Tuple[int, ...]], *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: return torch.ones(shape, dtype=dtype, device=device, **kwargs) def zeros(shape: Union[int, Tuple[int, ...]], *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: return torch.zeros(shape, dtype=dtype, device=device, **kwargs) def empty(shape: Union[int, Tuple[int, ...]], *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: return torch.empty(shape, dtype=dtype, device=device, **kwargs) # tril and triu do not call the keyword argument k -def tril(x: array, /, *, k: int = 0) -> array: +def tril(x: Array, /, *, k: int = 0) -> Array: return torch.tril(x, k) -def triu(x: array, /, *, k: int = 0) -> array: +def triu(x: Array, /, *, k: int = 0) -> Array: return torch.triu(x, k) # Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742 -def expand_dims(x: array, /, *, axis: int = 0) -> array: +def expand_dims(x: Array, /, *, axis: int = 0) -> Array: return torch.unsqueeze(x, axis) def astype( - x: array, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, device: Optional[Device] = None, -) -> array: +) -> Array: if device is not None: return x.to(device, dtype=dtype, copy=copy) return x.to(dtype=dtype, copy=copy) -def broadcast_arrays(*arrays: array) -> List[array]: +def broadcast_arrays(*arrays: Array) -> List[Array]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) return [torch.broadcast_to(a, shape) for a in arrays] @@ -678,7 +669,7 @@ def broadcast_arrays(*arrays: array) -> List[array]: UniqueInverseResult) # https://github.com/pytorch/pytorch/issues/70920 -def unique_all(x: array) -> UniqueAllResult: +def unique_all(x: Array) -> UniqueAllResult: # torch.unique doesn't support returning indices. # https://github.com/pytorch/pytorch/issues/36748. The workaround # suggested in that issue doesn't actually function correctly (it relies @@ -691,7 +682,7 @@ def unique_all(x: array) -> UniqueAllResult: # counts[torch.isnan(values)] = 1 # return UniqueAllResult(values, indices, inverse_indices, counts) -def unique_counts(x: array) -> UniqueCountsResult: +def unique_counts(x: Array) -> UniqueCountsResult: values, counts = torch.unique(x, return_counts=True) # torch.unique incorrectly gives a 0 count for nan values. 
@@ -699,14 +690,14 @@ def unique_counts(x: array) -> UniqueCountsResult: counts[torch.isnan(values)] = 1 return UniqueCountsResult(values, counts) -def unique_inverse(x: array) -> UniqueInverseResult: +def unique_inverse(x: Array) -> UniqueInverseResult: values, inverse = torch.unique(x, return_inverse=True) return UniqueInverseResult(values, inverse) -def unique_values(x: array) -> array: +def unique_values(x: Array) -> Array: return torch.unique(x) -def matmul(x1: array, x2: array, /, **kwargs) -> array: +def matmul(x1: Array, x2: Array, /, **kwargs) -> Array: # torch.matmul doesn't type promote (but differently from _fix_promotion) x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.matmul(x1, x2, **kwargs) @@ -714,12 +705,12 @@ def matmul(x1: array, x2: array, /, **kwargs) -> array: matrix_transpose = get_xp(torch)(_aliases.matrix_transpose) _vecdot = get_xp(torch)(_aliases.vecdot) -def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return _vecdot(x1, x2, axis=axis) # torch.tensordot uses dims instead of axes -def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array: +def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> Array: # Note: torch.tensordot fails with integer dtypes when there is only 1 # element in the axis (https://github.com/pytorch/pytorch/issues/84530). x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -727,7 +718,7 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], + dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]], *, _tuple=True, # Disallow nested tuples ) -> bool: """ @@ -762,7 +753,7 @@ def isdtype( else: return dtype == kind -def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array: +def take(x: Array, indices: Array, /, *, axis: Optional[int] = None, **kwargs) -> Array: if axis is None: if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") @@ -770,11 +761,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - return torch.index_select(x, axis, indices, **kwargs) -def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array: +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return torch.take_along_dim(x, indices, dim=axis) -def sign(x: array, /) -> array: +def sign(x: Array, /) -> Array: # torch sign() does not support complex numbers and does not propagate # nans. 
See https://github.com/data-apis/array-api-compat/issues/136 if x.dtype.is_complex: diff --git a/array_api_compat/torch/_typing.py b/array_api_compat/torch/_typing.py new file mode 100644 index 00000000..29ad3fa7 --- /dev/null +++ b/array_api_compat/torch/_typing.py @@ -0,0 +1,4 @@ +__all__ = ["Array", "DType", "Device"] + +from torch import dtype as DType, Tensor as Array +from ..common._typing import Device diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 3c9117ee..50e6a0d0 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -1,76 +1,75 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import torch - array = torch.Tensor - from typing import Union, Sequence, Literal +from typing import Union, Sequence, Literal -from torch.fft import * # noqa: F403 +import torch import torch.fft +from torch.fft import * # noqa: F403 + +from ._typing import Array # Several torch fft functions do not map axes to dim def fftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs) def ifftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs) def rfftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs) def irfftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs) def fftshift( - x: array, + x: Array, /, *, axes: Union[int, Sequence[int]] = None, **kwargs, -) -> array: +) -> Array: return torch.fft.fftshift(x, dim=axes, **kwargs) def ifftshift( - x: array, + x: Array, /, *, axes: Union[int, Sequence[int]] = None, **kwargs, -) -> array: +) -> Array: return torch.fft.ifftshift(x, dim=axes, **kwargs) diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index e26198b9..7b59a670 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,14 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import torch - array = torch.Tensor - from torch import dtype as Dtype - from typing import Optional, Union, Tuple, Literal - inf = float('inf') - -from ._aliases import _fix_promotion, sum +import torch +from typing import Optional, Union, Tuple from torch.linalg import * # noqa: F403 @@ -19,15 +12,17 @@ # outer is implemented in torch but aren't in the linalg namespace from torch import outer +from ._aliases import _fix_promotion, sum # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot +from ._typing import Array, DType # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 # torch.cross also does not support broadcasting when it would add new # dimensions https://github.com/pytorch/pytorch/issues/39656 -def 
cross(x1: array, x2: array, /, *, axis: int = -1) -> array: +def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)): raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}") @@ -36,7 +31,7 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: x1, x2 = torch.broadcast_tensors(x1, x2) return torch_linalg.cross(x1, x2, dim=axis) -def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array: from ._aliases import isdtype x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -58,7 +53,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: return res[..., 0, 0] return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) -def solve(x1: array, x2: array, /, **kwargs) -> array: +def solve(x1: Array, x2: Array, /, **kwargs) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve # whenever @@ -79,19 +74,20 @@ def solve(x1: array, x2: array, /, **kwargs) -> array: return torch.linalg.solve(x1, x2, **kwargs) # torch.trace doesn't support the offset argument and doesn't support stacking -def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: +def trace(x: Array, /, *, offset: int = 0, dtype: Optional[DType] = None) -> Array: # Use our wrapped sum to make sure it does upcasting correctly return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) def vector_norm( - x: array, + x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - ord: Union[int, float, Literal[inf, -inf]] = 2, + # float stands for inf | -inf, which are not valid for Literal + ord: Union[int, float, float] = 2, **kwargs, -) -> array: +) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None if axis == (): out = kwargs.get('out') diff --git a/tests/test_all.py b/tests/test_all.py index 10a2a95d..a0a7e2b0 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -15,6 +15,16 @@ from ._helpers import import_, wrapped_libraries import pytest +import typing + +TYPING_NAMES = { + "Array", + "Device", + "DType", + "Namespace", + "NestedSequence", + "SupportsBufferProtocol", +} @pytest.mark.parametrize("library", ["common"] + wrapped_libraries) def test_all(library): @@ -37,8 +47,11 @@ def test_all(library): dir_names = [n for n in dir(module) if not n.startswith('_')] if '__array_namespace_info__' in dir(module): dir_names.append('__array_namespace_info__') - ignore_all_names = getattr(module, '_all_ignore', []) - ignore_all_names += ['annotations', 'TYPE_CHECKING'] + ignore_all_names = set(getattr(module, '_all_ignore', ())) + ignore_all_names |= set(dir(typing)) + ignore_all_names |= {"annotations"} + if not module.__name__.endswith("._typing"): + ignore_all_names |= TYPING_NAMES dir_names = set(dir_names) - set(ignore_all_names) all_names = module.__all__ From 8d888d5104b145708c376439c7e12478e701e75a Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 27 Feb 2025 15:40:58 +0000 Subject: [PATCH 02/13] Re-add py.typed --- array_api_compat/py.typed | 0 setup.py | 5 ++++- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 array_api_compat/py.typed diff --git a/array_api_compat/py.typed b/array_api_compat/py.typed new file mode 100644 
index 00000000..e69de29b diff --git a/setup.py b/setup.py index 3d2b68a2..2368ccc4 100644 --- a/setup.py +++ b/setup.py @@ -33,5 +33,8 @@ "Programming Language :: Python :: 3.13", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", - ] + ], + package_data={ + "array_api_compat": ["py.typed"], + }, ) From 4e4e84ea79eca4d7448ee6432f68ca893806db95 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 4 Mar 2025 23:33:43 +0000 Subject: [PATCH 03/13] code review --- array_api_compat/common/_aliases.py | 2 +- array_api_compat/cupy/_aliases.py | 4 +-- array_api_compat/cupy/_typing.py | 38 ++++++++++++++--------------- array_api_compat/numpy/_aliases.py | 4 +-- array_api_compat/numpy/_typing.py | 38 ++++++++++++++--------------- array_api_compat/torch/_aliases.py | 1 + tests/test_all.py | 4 +-- 7 files changed, 44 insertions(+), 47 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 9f771f41..8eea4e47 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -177,7 +177,7 @@ class UniqueInverseResult(NamedTuple): inverse_indices: Array -def _unique_kwargs(xp: Namespace) -> dict[str, Any]: +def _unique_kwargs(xp: Namespace) -> dict[str, bool]: # Older versions of NumPy and CuPy do not have equal_nan. Rather than # trying to parse version numbers, just check if equal_nan is in the # signature. diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 6f03b6ce..88553ff6 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -126,10 +126,10 @@ def astype( # cupy.count_nonzero does not have keepdims def count_nonzero( - x: ndarray, + x: Array, axis=None, keepdims=False -) -> ndarray: +) -> Array: result = cp.count_nonzero(x, axis) if keepdims: if axis is None: diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index 07c11f78..87259285 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -3,31 +3,29 @@ __all__ = ["Array", "DType", "Device"] _all_ignore = ["cp"] -from typing import Union +from typing import Union, TYPE_CHECKING import cupy as cp from cupy import ndarray as Array from cupy.cuda.device import Device -try: +if TYPE_CHECKING: + # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] DType = cp.dtype[ - Union[ - cp.intp, - cp.int8, - cp.int16, - cp.int32, - cp.int64, - cp.uint8, - cp.uint16, - cp.uint32, - cp.uint64, - cp.float32, - cp.float64, - cp.complex64, - cp.complex128, - cp.bool_, - ] + cp.intp + | cp.int8 + | cp.int16 + | cp.int32 + | cp.int64 + | cp.uint8 + | cp.uint16 + | cp.uint32 + | cp.uint64 + | cp.float32 + | cp.float64 + | cp.complex64 + | cp.complex128 + | cp.bool_ ] -except TypeError: - # NumPy 1.x on Python 3.9 and 3.10 +else: DType = cp.dtype diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index aedaa7a3..816de4bf 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -128,10 +128,10 @@ def astype( # count_nonzero returns a python int for axis=None and keepdims=False # https://github.com/numpy/numpy/issues/17562 def count_nonzero( - x : ndarray, + x : Array, axis=None, keepdims=False -) -> ndarray: +) -> Array: result = np.count_nonzero(x, axis=axis, keepdims=keepdims) if axis is None and not keepdims: return np.asarray(result) diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index 5557c5e9..6a18a3b2 100644 --- 
a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -3,31 +3,29 @@ __all__ = ["Array", "DType", "Device"] _all_ignore = ["np"] -from typing import Literal, Union +from typing import Literal, TYPE_CHECKING import numpy as np from numpy import ndarray as Array Device = Literal["cpu"] -try: +if TYPE_CHECKING: + # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] DType = np.dtype[ - Union[ - np.intp, - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, - np.float32, - np.float64, - np.complex64, - np.complex128, - np.bool_, - ] + np.intp + | np.int8 + | np.int16 + | np.int32 + | np.int64 + | np.uint8 + | np.uint16 + | np.uint32 + | np.uint64 + | np.float32 + | np.float64 + | np.complex64 + | np.complex128 + | np.bool ] -except TypeError: - # NumPy 1.x on Python 3.9 and 3.10 +else: DType = np.dtype diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index cb3c3bad..e5ff11e0 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -5,6 +5,7 @@ from typing import List, Optional, Sequence, Tuple, Union import torch + from .._internal import get_xp from ..common import _aliases from ._info import __array_namespace_info__ diff --git a/tests/test_all.py b/tests/test_all.py index a0a7e2b0..eeb67e4b 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -17,14 +17,14 @@ import pytest import typing -TYPING_NAMES = { +TYPING_NAMES = frozenset(( "Array", "Device", "DType", "Namespace", "NestedSequence", "SupportsBufferProtocol", -} +)) @pytest.mark.parametrize("library", ["common"] + wrapped_libraries) def test_all(library): From 082c0525fbe7e9f22c68f8527cf16dd545532f07 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 4 Mar 2025 23:35:31 +0000 Subject: [PATCH 04/13] lint --- array_api_compat/common/_aliases.py | 2 +- array_api_compat/cupy/_typing.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8eea4e47..f88bcd09 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -5,7 +5,7 @@ from __future__ import annotations import inspect -from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union +from typing import NamedTuple, Optional, Sequence, Tuple, Union from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace from ._typing import Array, Device, DType, Namespace diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index 87259285..66af5d19 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -3,7 +3,7 @@ __all__ = ["Array", "DType", "Device"] _all_ignore = ["cp"] -from typing import Union, TYPE_CHECKING +from typing import TYPE_CHECKING import cupy as cp from cupy import ndarray as Array From ac2cb733e5d70a2eb06c253ec694dfa37b4bf317 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 4 Mar 2025 23:55:17 +0000 Subject: [PATCH 05/13] asarray --- array_api_compat/cupy/_aliases.py | 10 +--------- array_api_compat/dask/array/_aliases.py | 10 +--------- array_api_compat/numpy/_aliases.py | 10 +--------- 3 files changed, 3 insertions(+), 27 deletions(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 88553ff6..1e7f59f0 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -66,15 +66,7 @@ # asarray also adds the copy keyword, which is not 
present in numpy 1.0. def asarray( - obj: Union[ - Array, - bool, - int, - float, - complex, - NestedSequence[bool | int | float | complex], - SupportsBufferProtocol, - ], + obj: bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol, /, *, dtype: Optional[DType] = None, diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 1c922f3a..d8e4cb84 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -135,15 +135,7 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: Union[ - Array, - bool, - int, - float, - complex, - NestedSequence[bool | int | float | complex], - SupportsBufferProtocol, - ], + obj: bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol, /, *, dtype: Optional[DType] = None, diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 816de4bf..6fec1cce 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -73,15 +73,7 @@ def _supports_buffer_protocol(obj): # complicated enough that it's easier to define it separately for each module # rather than trying to combine everything into one function in common/ def asarray( - obj: Union[ - Array, - bool, - int, - float, - complex, - NestedSequence[bool | int | float | complex], - SupportsBufferProtocol, - ], + obj: bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol, /, *, dtype: Optional[DType] = None, From 84e28e7acd54a6e9d47f896710a4b85b786409c8 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 4 Mar 2025 23:55:32 +0000 Subject: [PATCH 06/13] fill_value --- array_api_compat/common/_aliases.py | 4 ++-- array_api_compat/torch/_aliases.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index f88bcd09..d34f3c64 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -67,7 +67,7 @@ def eye( def full( shape: Union[int, Tuple[int, ...]], - fill_value: Union[int, float], + fill_value: bool | complex, xp: Namespace, *, dtype: Optional[DType] = None, @@ -80,7 +80,7 @@ def full( def full_like( x: Array, /, - fill_value: Union[int, float], + fill_value: bool | complex, *, xp: Namespace, dtype: Optional[DType] = None, diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index e5ff11e0..341a4db6 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -602,7 +602,7 @@ def linspace(start: Union[int, float], # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 def full(shape: Union[int, Tuple[int, ...]], - fill_value: Union[bool, int, float, complex], + fill_value: bool | complex, *, dtype: Optional[DType] = None, device: Optional[Device] = None, From 8eaf862fde2c6cae3b80358e9570fc86926af04c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 4 Mar 2025 23:55:42 +0000 Subject: [PATCH 07/13] result_type --- array_api_compat/torch/_aliases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 341a4db6..07d97a78 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -116,7 +116,7 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type(*arrays_and_dtypes: 
Union[Array, DType, bool, int, float, complex]) -> DType: +def result_type(*arrays_and_dtypes: Array | DType | bool | complex) -> DType: if len(arrays_and_dtypes) == 0: raise TypeError("At least one array or dtype must be provided") if len(arrays_and_dtypes) == 1: From 40de1c9ca993b787d026383ddfaba6bb2dbfd8e9 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 4 Mar 2025 23:56:18 +0000 Subject: [PATCH 08/13] lint --- array_api_compat/cupy/_aliases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 1e7f59f0..4145ae87 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional, Union +from typing import Optional import cupy as cp From a213ab5494885df0075340b8bae8a755aaca0be3 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 5 Mar 2025 00:02:19 +0000 Subject: [PATCH 09/13] Arrays don't need to support buffer protocol --- array_api_compat/cupy/_aliases.py | 3 ++- array_api_compat/dask/array/_aliases.py | 3 ++- array_api_compat/numpy/_aliases.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 4145ae87..b76fef5e 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -66,7 +66,8 @@ # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol, + obj: Array | bool | complex | NestedSequence[bool | complex] + | SupportsBufferProtocol, /, *, dtype: Optional[DType] = None, diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d8e4cb84..9059db7e 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -135,7 +135,8 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. 
def asarray( - obj: bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol, + obj: Array | bool | complex | NestedSequence[bool | complex] + | SupportsBufferProtocol, /, *, dtype: Optional[DType] = None, diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 6fec1cce..2ea4e5aa 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -73,7 +73,8 @@ def _supports_buffer_protocol(obj): # complicated enough that it's easier to define it separately for each module # rather than trying to combine everything into one function in common/ def asarray( - obj: bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol, + obj: Array | bool | complex | NestedSequence[bool | complex] + | SupportsBufferProtocol, /, *, dtype: Optional[DType] = None, From 0adac278eb55968448fdaa9a86417e2f0428d526 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 5 Mar 2025 00:14:21 +0000 Subject: [PATCH 10/13] bool is a subclass of int --- array_api_compat/common/_aliases.py | 4 ++-- array_api_compat/cupy/_aliases.py | 3 +-- array_api_compat/dask/array/_aliases.py | 3 +-- array_api_compat/numpy/_aliases.py | 3 +-- array_api_compat/torch/_aliases.py | 4 ++-- 5 files changed, 7 insertions(+), 10 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index d34f3c64..07267956 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -67,7 +67,7 @@ def eye( def full( shape: Union[int, Tuple[int, ...]], - fill_value: bool | complex, + fill_value: complex, xp: Namespace, *, dtype: Optional[DType] = None, @@ -80,7 +80,7 @@ def full( def full_like( x: Array, /, - fill_value: bool | complex, + fill_value: complex, *, xp: Namespace, dtype: Optional[DType] = None, diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index b76fef5e..7b0ba6bb 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -66,8 +66,7 @@ # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: Array | bool | complex | NestedSequence[bool | complex] - | SupportsBufferProtocol, + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, dtype: Optional[DType] = None, diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 9059db7e..1244e603 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -135,8 +135,7 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. 
def asarray( - obj: Array | bool | complex | NestedSequence[bool | complex] - | SupportsBufferProtocol, + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, dtype: Optional[DType] = None, diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 2ea4e5aa..73ce9518 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -73,8 +73,7 @@ def _supports_buffer_protocol(obj): # complicated enough that it's easier to define it separately for each module # rather than trying to combine everything into one function in common/ def asarray( - obj: Array | bool | complex | NestedSequence[bool | complex] - | SupportsBufferProtocol, + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, dtype: Optional[DType] = None, diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 07d97a78..5e354692 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -116,7 +116,7 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type(*arrays_and_dtypes: Array | DType | bool | complex) -> DType: +def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: if len(arrays_and_dtypes) == 0: raise TypeError("At least one array or dtype must be provided") if len(arrays_and_dtypes) == 1: @@ -602,7 +602,7 @@ def linspace(start: Union[int, float], # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 def full(shape: Union[int, Tuple[int, ...]], - fill_value: bool | complex, + fill_value: complex, *, dtype: Optional[DType] = None, device: Optional[Device] = None, From 8fe4205d95f389c88b4cc0e35d3bb5070849822f Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 5 Mar 2025 00:20:21 +0000 Subject: [PATCH 11/13] reshape: copy kwarg is keyword-only --- array_api_compat/common/_aliases.py | 1 + array_api_compat/torch/_aliases.py | 1 + 2 files changed, 2 insertions(+) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 07267956..cc2512a0 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -410,6 +410,7 @@ def reshape( /, shape: Tuple[int, ...], xp: Namespace, + *, copy: Optional[bool] = None, **kwargs, ) -> Array: diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 5e354692..596c9780 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -539,6 +539,7 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array: def reshape(x: Array, /, shape: Tuple[int, ...], + *, copy: Optional[bool] = None, **kwargs) -> Array: if copy is not None: From 646fc617e512be642fe69ef3fa8c78a9faf7cd9b Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 5 Mar 2025 00:26:43 +0000 Subject: [PATCH 12/13] tensordot formatting --- array_api_compat/common/_aliases.py | 15 ++++++++------- array_api_compat/torch/_aliases.py | 9 ++++++++- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index cc2512a0..132ae10f 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -509,13 +509,14 @@ def matrix_transpose(x: Array, /, xp: Namespace) -> Array: raise ValueError("x must be at least 2-dimensional for matrix_transpose") return xp.swapaxes(x, -1, -2) -def tensordot(x1: Array, - x2: Array, - /, - xp: 
Namespace, - *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, +def tensordot( + x1: Array, + x2: Array, + /, + xp: Namespace, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, + **kwargs, ) -> Array: return xp.tensordot(x1, x2, axes=axes, **kwargs) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 596c9780..26b06310 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -712,7 +712,14 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: return _vecdot(x1, x2, axis=axis) # torch.tensordot uses dims instead of axes -def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> Array: +def tensordot( + x1: Array, + x2: Array, + /, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, + **kwargs, +) -> Array: # Note: torch.tensordot fails with integer dtypes when there is only 1 # element in the axis (https://github.com/pytorch/pytorch/issues/84530). x1, x2 = _fix_promotion(x1, x2, only_scalar=False) From 1dc495538156215a70084f7382f59471860593a6 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 21 Mar 2025 00:34:45 +0000 Subject: [PATCH 13/13] Reinstate explicit bool | complex --- array_api_compat/cupy/_aliases.py | 5 ++++- array_api_compat/dask/array/_aliases.py | 4 +++- array_api_compat/numpy/_aliases.py | 13 +++++++------ 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 7b0ba6bb..ebc7ccd9 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -64,9 +64,12 @@ _copy_default = object() + # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, + obj: ( + Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + ), /, *, dtype: Optional[DType] = None, diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 1244e603..e737cebd 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -135,7 +135,9 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, + obj: ( + Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + ), /, *, dtype: Optional[DType] = None, diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 73ce9518..6536d9a8 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -9,6 +9,7 @@ from ._typing import Array, Device, DType import numpy as np + bool = np.bool_ # Basic renames @@ -61,6 +62,7 @@ tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) + def _supports_buffer_protocol(obj): try: memoryview(obj) @@ -68,12 +70,15 @@ def _supports_buffer_protocol(obj): return False return True + # asarray also adds the copy keyword, which is not present in numpy 1.0. 
# asarray() is different enough between numpy, cupy, and dask, the logic # complicated enough that it's easier to define it separately for each module # rather than trying to combine everything into one function in common/ def asarray( - obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, + obj: ( + Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + ), /, *, dtype: Optional[DType] = None, @@ -119,11 +124,7 @@ def astype( # count_nonzero returns a python int for axis=None and keepdims=False # https://github.com/numpy/numpy/issues/17562 -def count_nonzero( - x : Array, - axis=None, - keepdims=False -) -> Array: +def count_nonzero(x: Array, axis=None, keepdims=False) -> Array: result = np.count_nonzero(x, axis=axis, keepdims=keepdims) if axis is None and not keepdims: return np.asarray(result)
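Background sketch for PATCH 10/13 ("bool is a subclass of int") and PATCH 13/13 ("Reinstate explicit bool | complex"), not taken from any commit above: under the PEP 484 numeric tower, a parameter annotated as complex already accepts int and float arguments, and bool is a runtime subclass of int, so the explicit "bool | complex" spelling reinstated for asarray is a readability choice rather than a requirement. The fill function below is a hypothetical stand-in used only to illustrate that rule; it is not code from this repository.

def fill(value: complex) -> complex:
    # PEP 484 numeric tower: type checkers accept int and float (and hence
    # bool, a subclass of int) wherever complex is expected.
    return value

print(fill(True))             # True
print(fill(3))                # 3
print(fill(2.5))              # 2.5
print(fill(1 + 2j))           # (1+2j)
print(issubclass(bool, int))  # True: bool really is a subclass of int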