diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 90c82c2..3f418d8 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -16,7 +16,11 @@ """ -__array_api_version__ = "2022.12" +# Warning: __array_api_version__ could change globally with +# set_array_api_strict_flags(). This should always be accessed as an +# attribute, like xp.__array_api_version__, or using +# array_api_strict.get_array_api_strict_flags()['api_version']. +from ._flags import API_VERSION as __array_api_version__ __all__ = ["__array_api_version__"] @@ -244,7 +248,7 @@ __all__ += ["linalg"] -from .linalg import matmul, tensordot, matrix_transpose, vecdot +from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot __all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"] @@ -284,6 +288,17 @@ __all__ += ["all", "any"] +# Helper functions that are not part of the standard + +from ._flags import ( + set_array_api_strict_flags, + get_array_api_strict_flags, + reset_array_api_strict_flags, + ArrayAPIStrictFlags, +) + +__all__ += ['set_array_api_strict_flags', 'get_array_api_strict_flags', 'reset_array_api_strict_flags', 'ArrayAPIStrictFlags'] + from . import _version __version__ = _version.get_versions()['version'] del _version diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 39808f0..2b9155a 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -17,7 +17,6 @@ import operator from enum import IntEnum -import warnings from ._creation_functions import asarray from ._dtypes import ( @@ -32,6 +31,7 @@ _result_type, _dtype_categories, ) +from ._flags import get_array_api_strict_flags, set_array_api_strict_flags from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex import types @@ -427,13 +427,17 @@ def _validate_index(self, key): "the Array API)" ) elif isinstance(i, Array): - if i.dtype in _boolean_dtypes and len(_key) != 1: - assert isinstance(key, tuple) # sanity check - raise IndexError( - f"Single-axes index {i} is a boolean array and " - f"{len(key)=}, but masking is only specified in the " - "Array API when the array is the sole index." - ) + if i.dtype in _boolean_dtypes: + if len(_key) != 1: + assert isinstance(key, tuple) # sanity check + raise IndexError( + f"Single-axes index {i} is a boolean array and " + f"{len(key)=}, but masking is only specified in the " + "Array API when the array is the sole index." + ) + if not get_array_api_strict_flags()['data_dependent_shapes']: + raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") + elif i.dtype in _integer_dtypes and i.ndim != 0: raise IndexError( f"Single-axes index {i} is a non-zero-dimensional " @@ -482,10 +486,21 @@ def __and__(self: Array, other: Union[int, bool, Array], /) -> Array: def __array_namespace__( self: Array, /, *, api_version: Optional[str] = None ) -> types.ModuleType: - if api_version is not None and api_version not in ["2021.12", "2022.12"]: - raise ValueError(f"Unrecognized array API version: {api_version!r}") - if api_version == "2021.12": - warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") + """ + Return the array_api_strict namespace corresponding to api_version. + + The default API version is '2022.12'. Note that '2021.12' is supported, + but currently identical to '2022.12'. + + For array_api_strict, calling this function with api_version will set + the API version for the array_api_strict module globally. This can + also be achieved with the + {func}`array_api_strict.set_array_api_strict_flags` function. If you + want to only set the version locally, use the + {class}`array_api_strict.ArrayApiStrictFlags` context manager. + + """ + set_array_api_strict_flags(api_version=api_version) import array_api_strict return array_api_strict diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py new file mode 100644 index 0000000..6cc503a --- /dev/null +++ b/array_api_strict/_flags.py @@ -0,0 +1,300 @@ +""" +These functions configure global flags that allow array-api-strict to be +used in different "modes". These modes include + +- Changing to different supported versions of the standard. +- Enabling or disabling different optional behaviors (such as data-dependent + shapes). +- Enabling or disabling different optional extensions. + +None of these functions are part of the standard itself. A typical array API +library will only support one particular configuration of these flags. + +""" + +import functools +import os +import warnings + +import array_api_strict + +supported_versions = ( + "2021.12", + "2022.12", +) + +API_VERSION = default_version = "2022.12" + +DATA_DEPENDENT_SHAPES = True + +all_extensions = ( + "linalg", + "fft", +) + +extension_versions = { + "linalg": "2021.12", + "fft": "2022.12", +} + +ENABLED_EXTENSIONS = default_extensions = ( + "linalg", + "fft", +) +# Public functions + +def set_array_api_strict_flags( + *, + api_version=None, + data_dependent_shapes=None, + enabled_extensions=None, +): + """ + Set the array-api-strict flags to the specified values. + + Flags are global variables that enable or disable array-api-strict + behaviors. + + .. note:: + + This function is **not** part of the array API standard. It only exists + in array-api-strict. + + - `api_version`: The version of the standard to use. Supported + versions are: ``{supported_versions}``. The default version number is + ``{default_version!r}``. + + Note that 2021.12 is supported, but currently gives the same thing as + 2022.12 (except that the fft extension will be disabled). + + - `data_dependent_shapes`: Whether data-dependent shapes are enabled in + array-api-strict. + + This flag is enabled by default. Array libraries that use computation + graphs may not be able to support functions whose output shapes depend + on the input data. + + The functions that make use of data-dependent shapes, and are therefore + disabled by setting this flag to False are + + - `unique_all`, `unique_counts`, `unique_inverse`, and `unique_values`. + - `nonzero` + - Boolean array indexing + - `repeat` when the `repeats` argument is an array (requires 2023.12 + version of the standard) + + See + https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html + for more details. + + - `enabled_extensions`: A list of extensions that are enabled in + array-api-strict. The default is ``{default_extensions}``. Note that + some extensions require a minimum version of the standard. + + The flags can also be changed by setting :ref:`environment variables + `. + + Examples + -------- + + >>> from array_api_strict import set_array_api_strict_flags + + >>> # Set the standard version to 2021.12 + >>> set_array_api_strict_flags(api_version="2021.12") + + >>> # Disable data-dependent shapes + >>> set_array_api_strict_flags(data_dependent_shapes=False) + + >>> # Enable only the linalg extension (disable the fft extension) + >>> set_array_api_strict_flags(enabled_extensions=["linalg"]) + + See Also + -------- + + get_array_api_strict_flags: Get the current values of flags. + reset_array_api_strict_flags: Reset the flags to their default values. + ArrayAPIStrictFlags: A context manager to temporarily set the flags. + + """ + global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS + + if api_version is not None: + if api_version not in supported_versions: + raise ValueError(f"Unsupported standard version {api_version!r}") + if api_version == "2021.12": + warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") + API_VERSION = api_version + array_api_strict.__array_api_version__ = API_VERSION + + if data_dependent_shapes is not None: + DATA_DEPENDENT_SHAPES = data_dependent_shapes + + if enabled_extensions is not None: + for extension in enabled_extensions: + if extension not in all_extensions: + raise ValueError(f"Unsupported extension {extension}") + if extension_versions[extension] > API_VERSION: + raise ValueError( + f"Extension {extension} requires standard version " + f"{extension_versions[extension]} or later" + ) + ENABLED_EXTENSIONS = tuple(enabled_extensions) + else: + ENABLED_EXTENSIONS = tuple([ext for ext in all_extensions if extension_versions[ext] <= API_VERSION]) + +# We have to do this separately or it won't get added as the docstring +set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( + supported_versions=supported_versions, + default_version=default_version, + default_extensions=default_extensions, +) + +def get_array_api_strict_flags(): + """ + Get the current array-api-strict flags. + + .. note:: + + This function is **not** part of the array API standard. It only exists + in array-api-strict. + + Returns + ------- + dict + A dictionary containing the current array-api-strict flags. + + Examples + -------- + + >>> from array_api_strict import get_array_api_strict_flags + >>> flags = get_array_api_strict_flags() + >>> flags + {'api_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')} + + See Also + -------- + + set_array_api_strict_flags: Set one or more flags to a given value. + reset_array_api_strict_flags: Reset the flags to their default values. + ArrayAPIStrictFlags: A context manager to temporarily set the flags. + + """ + return { + "api_version": API_VERSION, + "data_dependent_shapes": DATA_DEPENDENT_SHAPES, + "enabled_extensions": ENABLED_EXTENSIONS, + } + + +def reset_array_api_strict_flags(): + """ + Reset the array-api-strict flags to their default values. + + This will also reset any flags that were set by :ref:`environment + variables ` back to their default values. + + .. note:: + + This function is **not** part of the array API standard. It only exists + in array-api-strict. + + See :func:`set_array_api_strict_flags` for a list of flags and their + default values. + + Examples + -------- + + >>> from array_api_strict import reset_array_api_strict_flags + >>> reset_array_api_strict_flags() + + See Also + -------- + + get_array_api_strict_flags: Get the current values of flags. + set_array_api_strict_flags: Set one or more flags to a given value. + ArrayAPIStrictFlags: A context manager to temporarily set the flags. + + """ + global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS + API_VERSION = default_version + array_api_strict.__array_api_version__ = API_VERSION + DATA_DEPENDENT_SHAPES = True + ENABLED_EXTENSIONS = default_extensions + + +class ArrayAPIStrictFlags: + """ + A context manager to temporarily set the array-api-strict flags. + + .. note:: + + This class is **not** part of the array API standard. It only exists + in array-api-strict. + + See :func:`set_array_api_strict_flags` for a + description of the available flags. + + See Also + -------- + + get_array_api_strict_flags: Get the current values of flags. + set_array_api_strict_flags: Set one or more flags to a given value. + reset_array_api_strict_flags: Reset the flags to their default values. + + """ + def __init__(self, *, api_version=None, data_dependent_shapes=None, + enabled_extensions=None): + self.kwargs = { + "api_version": api_version, + "data_dependent_shapes": data_dependent_shapes, + "enabled_extensions": enabled_extensions, + } + self.old_flags = get_array_api_strict_flags() + + def __enter__(self): + set_array_api_strict_flags(**self.kwargs) + + def __exit__(self, exc_type, exc_value, traceback): + set_array_api_strict_flags(**self.old_flags) + +# Private functions + +def set_flags_from_environment(): + if "ARRAY_API_STRICT_API_VERSION" in os.environ: + set_array_api_strict_flags( + api_version=os.environ["ARRAY_API_STRICT_API_VERSION"] + ) + + if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ: + set_array_api_strict_flags( + data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true" + ) + + if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ: + set_array_api_strict_flags( + enabled_extensions=os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",") + ) + +set_flags_from_environment() + +def requires_data_dependent_shapes(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not DATA_DEPENDENT_SHAPES: + raise RuntimeError(f"The function {func.__name__} requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") + return func(*args, **kwargs) + return wrapper + +def requires_extension(extension): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if extension not in ENABLED_EXTENSIONS: + if extension == 'linalg' \ + and func.__name__ in ['matmul', 'tensordot', + 'matrix_transpose', 'vecdot']: + raise RuntimeError(f"The linalg extension has been disabled for array-api-strict. However, {func.__name__} is also present in the main array_api_strict namespace and may be used from there.") + raise RuntimeError(f"The function {func.__name__} requires the {extension} extension, but it has been disabled for array-api-strict") + return func(*args, **kwargs) + return wrapper + return decorator diff --git a/array_api_strict/_linear_algebra_functions.py b/array_api_strict/_linear_algebra_functions.py new file mode 100644 index 0000000..1ff08d4 --- /dev/null +++ b/array_api_strict/_linear_algebra_functions.py @@ -0,0 +1,68 @@ +""" +These functions are all also defined in the linalg extension, but we include +them here with wrappers in linalg so that the wrappers can be disabled if the +linalg extension is disabled in the flags. + +""" + +from __future__ import annotations + +from ._dtypes import _numeric_dtypes + +from ._array_object import Array + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ._typing import Sequence, Tuple, Union + +import numpy.linalg +import numpy as np + +# Note: matmul is the numpy top-level namespace but not in np.linalg +def matmul(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.matmul `. + + See its docstring for more information. + """ + # Note: the restriction to numeric dtypes only is different from + # np.matmul. + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in matmul') + + return Array._new(np.matmul(x1._array, x2._array)) + +# Note: tensordot is the numpy top-level namespace but not in np.linalg + +# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. +def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: + # Note: the restriction to numeric dtypes only is different from + # np.tensordot. + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in tensordot') + + return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) + +# Note: this function is new in the array API spec. Unlike transpose, it only +# transposes the last two axes. +def matrix_transpose(x: Array, /) -> Array: + if x.ndim < 2: + raise ValueError("x must be at least 2-dimensional for matrix_transpose") + return Array._new(np.swapaxes(x._array, -1, -2)) + +# Note: vecdot is not in NumPy +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in vecdot') + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same size along the given axis") + + x1_, x2_ = np.broadcast_arrays(x1._array, x2._array) + x1_ = np.moveaxis(x1_, axis, -1) + x2_ = np.moveaxis(x2_, axis, -1) + + res = x1_[..., None, :] @ x2_[..., None] + return Array._new(res[..., 0, 0]) diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index f4b2f56..9781531 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -2,6 +2,7 @@ from ._array_object import Array from ._dtypes import _result_type, _real_numeric_dtypes +from ._flags import requires_data_dependent_shapes from typing import Optional, Tuple @@ -30,6 +31,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims))) +@requires_data_dependent_shapes def nonzero(x: Array, /) -> Tuple[Array, ...]: """ Array API compatible wrapper for :py:func:`np.nonzero `. diff --git a/array_api_strict/_set_functions.py b/array_api_strict/_set_functions.py index 0b4132c..e6ca939 100644 --- a/array_api_strict/_set_functions.py +++ b/array_api_strict/_set_functions.py @@ -2,6 +2,8 @@ from ._array_object import Array +from ._flags import requires_data_dependent_shapes + from typing import NamedTuple import numpy as np @@ -35,6 +37,7 @@ class UniqueInverseResult(NamedTuple): inverse_indices: Array +@requires_data_dependent_shapes def unique_all(x: Array, /) -> UniqueAllResult: """ Array API compatible wrapper for :py:func:`np.unique `. @@ -59,6 +62,7 @@ def unique_all(x: Array, /) -> UniqueAllResult: ) +@requires_data_dependent_shapes def unique_counts(x: Array, /) -> UniqueCountsResult: res = np.unique( x._array, @@ -71,6 +75,7 @@ def unique_counts(x: Array, /) -> UniqueCountsResult: return UniqueCountsResult(*[Array._new(i) for i in res]) +@requires_data_dependent_shapes def unique_inverse(x: Array, /) -> UniqueInverseResult: """ Array API compatible wrapper for :py:func:`np.unique `. @@ -90,6 +95,7 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult: return UniqueInverseResult(Array._new(values), Array._new(inverse_indices)) +@requires_data_dependent_shapes def unique_values(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.unique `. diff --git a/array_api_strict/fft.py b/array_api_strict/fft.py index b50e9e3..7f427e5 100644 --- a/array_api_strict/fft.py +++ b/array_api_strict/fft.py @@ -15,9 +15,11 @@ ) from ._array_object import Array, CPU_DEVICE from ._data_type_functions import astype +from ._flags import requires_extension import numpy as np +@requires_extension('fft') def fft( x: Array, /, @@ -40,6 +42,7 @@ def fft( return astype(res, complex64) return res +@requires_extension('fft') def ifft( x: Array, /, @@ -62,6 +65,7 @@ def ifft( return astype(res, complex64) return res +@requires_extension('fft') def fftn( x: Array, /, @@ -84,6 +88,7 @@ def fftn( return astype(res, complex64) return res +@requires_extension('fft') def ifftn( x: Array, /, @@ -106,6 +111,7 @@ def ifftn( return astype(res, complex64) return res +@requires_extension('fft') def rfft( x: Array, /, @@ -128,6 +134,7 @@ def rfft( return astype(res, complex64) return res +@requires_extension('fft') def irfft( x: Array, /, @@ -150,6 +157,7 @@ def irfft( return astype(res, float32) return res +@requires_extension('fft') def rfftn( x: Array, /, @@ -172,6 +180,7 @@ def rfftn( return astype(res, complex64) return res +@requires_extension('fft') def irfftn( x: Array, /, @@ -194,6 +203,7 @@ def irfftn( return astype(res, float32) return res +@requires_extension('fft') def hfft( x: Array, /, @@ -216,6 +226,7 @@ def hfft( return astype(res, float32) return res +@requires_extension('fft') def ihfft( x: Array, /, @@ -238,6 +249,7 @@ def ihfft( return astype(res, complex64) return res +@requires_extension('fft') def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.fftfreq `. @@ -248,6 +260,7 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Ar raise ValueError(f"Unsupported device {device!r}") return Array._new(np.fft.fftfreq(n, d=d)) +@requires_extension('fft') def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.rfftfreq `. @@ -258,6 +271,7 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> A raise ValueError(f"Unsupported device {device!r}") return Array._new(np.fft.rfftfreq(n, d=d)) +@requires_extension('fft') def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.fftshift `. @@ -268,6 +282,7 @@ def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: raise TypeError("Only floating-point dtypes are allowed in fftshift") return Array._new(np.fft.fftshift(x._array, axes=axes)) +@requires_extension('fft') def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.ifftshift `. diff --git a/array_api_strict/linalg.py b/array_api_strict/linalg.py index 78e9ec4..e1998fa 100644 --- a/array_api_strict/linalg.py +++ b/array_api_strict/linalg.py @@ -12,6 +12,7 @@ from ._manipulation_functions import reshape from ._elementwise_functions import conj from ._array_object import Array +from ._flags import requires_extension try: from numpy._core.numeric import normalize_axis_tuple @@ -46,6 +47,7 @@ class SVDResult(NamedTuple): # Note: the inclusion of the upper keyword is different from # np.linalg.cholesky, which does not have it. +@requires_extension('linalg') def cholesky(x: Array, /, *, upper: bool = False) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.cholesky `. @@ -65,6 +67,7 @@ def cholesky(x: Array, /, *, upper: bool = False) -> Array: return Array._new(L) # Note: cross is the numpy top-level namespace, not np.linalg +@requires_extension('linalg') def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: """ Array API compatible wrapper for :py:func:`np.cross `. @@ -80,6 +83,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: raise ValueError('cross() dimension must equal 3') return Array._new(np.cross(x1._array, x2._array, axis=axis)) +@requires_extension('linalg') def det(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.det `. @@ -93,6 +97,7 @@ def det(x: Array, /) -> Array: return Array._new(np.linalg.det(x._array)) # Note: diagonal is the numpy top-level namespace, not np.linalg +@requires_extension('linalg') def diagonal(x: Array, /, *, offset: int = 0) -> Array: """ Array API compatible wrapper for :py:func:`np.diagonal `. @@ -103,7 +108,7 @@ def diagonal(x: Array, /, *, offset: int = 0) -> Array: # operates on the first two axes by default return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1)) - +@requires_extension('linalg') def eigh(x: Array, /) -> EighResult: """ Array API compatible wrapper for :py:func:`np.linalg.eigh `. @@ -120,6 +125,7 @@ def eigh(x: Array, /) -> EighResult: return EighResult(*map(Array._new, np.linalg.eigh(x._array))) +@requires_extension('linalg') def eigvalsh(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.eigvalsh `. @@ -133,6 +139,7 @@ def eigvalsh(x: Array, /) -> Array: return Array._new(np.linalg.eigvalsh(x._array)) +@requires_extension('linalg') def inv(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.inv `. @@ -146,28 +153,13 @@ def inv(x: Array, /) -> Array: return Array._new(np.linalg.inv(x._array)) - -# Note: matmul is the numpy top-level namespace but not in np.linalg -def matmul(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.matmul `. - - See its docstring for more information. - """ - # Note: the restriction to numeric dtypes only is different from - # np.matmul. - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in matmul') - - return Array._new(np.matmul(x1._array, x2._array)) - - # Note: the name here is different from norm(). The array API norm is split # into matrix_norm and vector_norm(). # The type for ord should be Optional[Union[int, float, Literal[np.inf, # -np.inf, 'fro', 'nuc']]], but Literal does not support floating-point # literals. +@requires_extension('linalg') def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.norm `. @@ -182,6 +174,7 @@ def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord)) +@requires_extension('linalg') def matrix_power(x: Array, n: int, /) -> Array: """ Array API compatible wrapper for :py:func:`np.matrix_power `. @@ -197,6 +190,7 @@ def matrix_power(x: Array, n: int, /) -> Array: return Array._new(np.linalg.matrix_power(x._array, n)) # Note: the keyword argument name rtol is different from np.linalg.matrix_rank +@requires_extension('linalg') def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.matrix_rank `. @@ -219,14 +213,8 @@ def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> A return Array._new(np.count_nonzero(S > tol, axis=-1)) -# Note: this function is new in the array API spec. Unlike transpose, it only -# transposes the last two axes. -def matrix_transpose(x: Array, /) -> Array: - if x.ndim < 2: - raise ValueError("x must be at least 2-dimensional for matrix_transpose") - return Array._new(np.swapaxes(x._array, -1, -2)) - # Note: outer is the numpy top-level namespace, not np.linalg +@requires_extension('linalg') def outer(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.outer `. @@ -245,6 +233,7 @@ def outer(x1: Array, x2: Array, /) -> Array: return Array._new(np.outer(x1._array, x2._array)) # Note: the keyword argument name rtol is different from np.linalg.pinv +@requires_extension('linalg') def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.pinv `. @@ -262,6 +251,7 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: rtol = max(x.shape[-2:]) * finfo(x.dtype).eps return Array._new(np.linalg.pinv(x._array, rcond=rtol)) +@requires_extension('linalg') def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: """ Array API compatible wrapper for :py:func:`np.linalg.qr `. @@ -277,6 +267,7 @@ def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRRe # np.linalg.qr, which only returns a tuple. return QRResult(*map(Array._new, np.linalg.qr(x._array, mode=mode))) +@requires_extension('linalg') def slogdet(x: Array, /) -> SlogdetResult: """ Array API compatible wrapper for :py:func:`np.linalg.slogdet `. @@ -335,6 +326,7 @@ def _solve(a, b): return wrap(r.astype(result_t, copy=False)) +@requires_extension('linalg') def solve(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.solve `. @@ -348,6 +340,7 @@ def solve(x1: Array, x2: Array, /) -> Array: return Array._new(_solve(x1._array, x2._array)) +@requires_extension('linalg') def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: """ Array API compatible wrapper for :py:func:`np.linalg.svd `. @@ -365,23 +358,14 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: # Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to # np.linalg.svd(compute_uv=False). +@requires_extension('linalg') def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]: if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in svdvals') return Array._new(np.linalg.svd(x._array, compute_uv=False)) -# Note: tensordot is the numpy top-level namespace but not in np.linalg - -# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. -def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: - # Note: the restriction to numeric dtypes only is different from - # np.tensordot. - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in tensordot') - - return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) - # Note: trace is the numpy top-level namespace, not np.linalg +@requires_extension('linalg') def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.trace `. @@ -404,29 +388,12 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr # operates on the first two axes by default return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype))) -# Note: vecdot is not in NumPy -def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in vecdot') - ndim = max(x1.ndim, x2.ndim) - x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) - x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) - if x1_shape[axis] != x2_shape[axis]: - raise ValueError("x1 and x2 must have the same size along the given axis") - - x1_, x2_ = np.broadcast_arrays(x1._array, x2._array) - x1_ = np.moveaxis(x1_, axis, -1) - x2_ = np.moveaxis(x2_, axis, -1) - - res = x1_[..., None, :] @ x2_[..., None] - return Array._new(res[..., 0, 0]) - - # Note: the name here is different from norm(). The array API norm is split # into matrix_norm and vector_norm(). # The type for ord should be Optional[Union[int, float, Literal[np.inf, # -np.inf]]] but Literal does not support floating-point literals. +@requires_extension('linalg') def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.norm `. @@ -472,4 +439,35 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No return res +# These functions are also in the main namespace. We define them here as +# wrappers so that they can still be disabled when the linalg extension is +# disabled without disabling the versions in the main namespace. + +# Note: matmul is the numpy top-level namespace but not in np.linalg +@requires_extension('linalg') +def matmul(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.matmul `. + + See its docstring for more information. + """ + from ._linear_algebra_functions import matmul + return matmul(x1, x2) + +# Note: tensordot is the numpy top-level namespace but not in np.linalg +@requires_extension('linalg') +def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: + from ._linear_algebra_functions import tensordot + return tensordot(x1, x2, axes=axes) + +@requires_extension('linalg') +def matrix_transpose(x: Array, /) -> Array: + from ._linear_algebra_functions import matrix_transpose + return matrix_transpose(x) + +@requires_extension('linalg') +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: + from ._linear_algebra_functions import vecdot + return vecdot(x1, x2, axis=axis) + __all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm'] diff --git a/array_api_strict/tests/conftest.py b/array_api_strict/tests/conftest.py new file mode 100644 index 0000000..5000d5d --- /dev/null +++ b/array_api_strict/tests/conftest.py @@ -0,0 +1,9 @@ +from .._flags import reset_array_api_strict_flags + +import pytest + +@pytest.fixture(autouse=True) +def reset_flags(): + reset_array_api_strict_flags() + yield + reset_array_api_strict_flags() diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index a0b6132..bae0553 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -402,9 +402,17 @@ def test_array_keys_use_private_array(): def test_array_namespace(): a = ones((3, 3)) assert a.__array_namespace__() == array_api_strict + assert array_api_strict.__array_api_version__ == "2022.12" + assert a.__array_namespace__(api_version=None) is array_api_strict + assert array_api_strict.__array_api_version__ == "2022.12" + assert a.__array_namespace__(api_version="2022.12") is array_api_strict + assert array_api_strict.__array_api_version__ == "2022.12" + with pytest.warns(UserWarning): assert a.__array_namespace__(api_version="2021.12") is array_api_strict + assert array_api_strict.__array_api_version__ == "2021.12" + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2023.12")) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py new file mode 100644 index 0000000..303c930 --- /dev/null +++ b/array_api_strict/tests/test_flags.py @@ -0,0 +1,188 @@ +from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags, + reset_array_api_strict_flags) + +from .. import (asarray, unique_all, unique_counts, unique_inverse, + unique_values, nonzero) + +import array_api_strict as xp + +import pytest + +def test_flags(): + # Test defaults + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2022.12', + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } + + # Test setting flags + set_array_api_strict_flags(data_dependent_shapes=False) + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2022.12', + 'data_dependent_shapes': False, + 'enabled_extensions': ('linalg', 'fft'), + } + set_array_api_strict_flags(enabled_extensions=('fft',)) + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2022.12', + 'data_dependent_shapes': False, + 'enabled_extensions': ('fft',), + } + # Make sure setting the version to 2021.12 disables fft and issues a + # warning. + with pytest.warns(UserWarning) as record: + set_array_api_strict_flags(api_version='2021.12') + assert len(record) == 1 + assert '2021.12' in str(record[0].message) + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2021.12', + 'data_dependent_shapes': False, + 'enabled_extensions': ('linalg',), + } + + # Test setting flags with invalid values + pytest.raises(ValueError, lambda: + set_array_api_strict_flags(api_version='2020.12')) + pytest.raises(ValueError, lambda: set_array_api_strict_flags( + enabled_extensions=('linalg', 'fft', 'invalid'))) + pytest.raises(ValueError, lambda: set_array_api_strict_flags( + api_version='2021.12', + enabled_extensions=('linalg', 'fft'))) + + # Test resetting flags + with pytest.warns(UserWarning): + set_array_api_strict_flags( + api_version='2021.12', + data_dependent_shapes=False, + enabled_extensions=()) + reset_array_api_strict_flags() + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2022.12', + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } + +def test_api_version(): + # Test defaults + assert xp.__array_api_version__ == '2022.12' + + # Test setting the version + set_array_api_strict_flags(api_version='2021.12') + assert xp.__array_api_version__ == '2021.12' + +def test_data_dependent_shapes(): + a = asarray([0, 0, 1, 2, 2]) + mask = asarray([True, False, True, False, True]) + + # Should not error + unique_all(a) + unique_counts(a) + unique_inverse(a) + unique_values(a) + nonzero(a) + a[mask] + # TODO: add repeat when it is implemented + + set_array_api_strict_flags(data_dependent_shapes=False) + + pytest.raises(RuntimeError, lambda: unique_all(a)) + pytest.raises(RuntimeError, lambda: unique_counts(a)) + pytest.raises(RuntimeError, lambda: unique_inverse(a)) + pytest.raises(RuntimeError, lambda: unique_values(a)) + pytest.raises(RuntimeError, lambda: nonzero(a)) + pytest.raises(RuntimeError, lambda: a[mask]) + +linalg_examples = { + 'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)), + 'cross': lambda: xp.linalg.cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])), + 'det': lambda: xp.linalg.det(xp.eye(3)), + 'diagonal': lambda: xp.linalg.diagonal(xp.eye(3)), + 'eigh': lambda: xp.linalg.eigh(xp.eye(3)), + 'eigvalsh': lambda: xp.linalg.eigvalsh(xp.eye(3)), + 'inv': lambda: xp.linalg.inv(xp.eye(3)), + 'matmul': lambda: xp.linalg.matmul(xp.eye(3), xp.eye(3)), + 'matrix_norm': lambda: xp.linalg.matrix_norm(xp.eye(3)), + 'matrix_power': lambda: xp.linalg.matrix_power(xp.eye(3), 2), + 'matrix_rank': lambda: xp.linalg.matrix_rank(xp.eye(3)), + 'matrix_transpose': lambda: xp.linalg.matrix_transpose(xp.eye(3)), + 'outer': lambda: xp.linalg.outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), + 'pinv': lambda: xp.linalg.pinv(xp.eye(3)), + 'qr': lambda: xp.linalg.qr(xp.eye(3)), + 'slogdet': lambda: xp.linalg.slogdet(xp.eye(3)), + 'solve': lambda: xp.linalg.solve(xp.eye(3), xp.eye(3)), + 'svd': lambda: xp.linalg.svd(xp.eye(3)), + 'svdvals': lambda: xp.linalg.svdvals(xp.eye(3)), + 'tensordot': lambda: xp.linalg.tensordot(xp.eye(3), xp.eye(3)), + 'trace': lambda: xp.linalg.trace(xp.eye(3)), + 'vecdot': lambda: xp.linalg.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), + 'vector_norm': lambda: xp.linalg.vector_norm(xp.asarray([1., 2., 3.])), +} + +assert set(linalg_examples) == set(xp.linalg.__all__) + +linalg_main_namespace_examples = { + 'matmul': lambda: xp.matmul(xp.eye(3), xp.eye(3)), + 'matrix_transpose': lambda: xp.matrix_transpose(xp.eye(3)), + 'tensordot': lambda: xp.tensordot(xp.eye(3), xp.eye(3)), + 'vecdot': lambda: xp.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), +} + +assert set(linalg_main_namespace_examples) == set(xp.__all__) & set(xp.linalg.__all__) + +@pytest.mark.parametrize('func_name', linalg_examples.keys()) +def test_linalg(func_name): + func = linalg_examples[func_name] + if func_name in linalg_main_namespace_examples: + main_namespace_func = linalg_main_namespace_examples[func_name] + else: + main_namespace_func = lambda: None + + # First make sure the example actually works + func() + main_namespace_func() + + set_array_api_strict_flags(enabled_extensions=()) + pytest.raises(RuntimeError, func) + main_namespace_func() + + set_array_api_strict_flags(enabled_extensions=('linalg',)) + func() + main_namespace_func() + +fft_examples = { + 'fft': lambda: xp.fft.fft(xp.asarray([0j, 1j, 0j, 0j])), + 'ifft': lambda: xp.fft.ifft(xp.asarray([0j, 1j, 0j, 0j])), + 'fftn': lambda: xp.fft.fftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'ifftn': lambda: xp.fft.ifftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'rfft': lambda: xp.fft.rfft(xp.asarray([0., 1., 0., 0.])), + 'irfft': lambda: xp.fft.irfft(xp.asarray([0j, 1j, 0j, 0j])), + 'rfftn': lambda: xp.fft.rfftn(xp.asarray([[0., 1.], [0., 0.]])), + 'irfftn': lambda: xp.fft.irfftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'hfft': lambda: xp.fft.hfft(xp.asarray([0j, 1j, 0j, 0j])), + 'ihfft': lambda: xp.fft.ihfft(xp.asarray([0., 1., 0., 0.])), + 'fftfreq': lambda: xp.fft.fftfreq(4), + 'rfftfreq': lambda: xp.fft.rfftfreq(4), + 'fftshift': lambda: xp.fft.fftshift(xp.asarray([0j, 1j, 0j, 0j])), + 'ifftshift': lambda: xp.fft.ifftshift(xp.asarray([0j, 1j, 0j, 0j])), +} + +assert set(fft_examples) == set(xp.fft.__all__) + +@pytest.mark.parametrize('func_name', fft_examples.keys()) +def test_fft(func_name): + func = fft_examples[func_name] + + # First make sure the example actually works + func() + + set_array_api_strict_flags(enabled_extensions=()) + pytest.raises(RuntimeError, func) + + set_array_api_strict_flags(enabled_extensions=('fft',)) + func() diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 0000000..e703a63 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,48 @@ +API Reference +============= + +.. automodule:: array_api_strict + +Array API Strict Flags +---------------------- + +.. automodule:: array_api_strict._flags + +.. currentmodule:: array_api_strict + +.. autofunction:: get_array_api_strict_flags +.. autofunction:: set_array_api_strict_flags +.. autofunction:: reset_array_api_strict_flags +.. autoclass:: ArrayAPIStrictFlags + +.. _environment-variables: + +Environment Variables +~~~~~~~~~~~~~~~~~~~~~ + +Flags can also be set with environment variables. +:func:`set_array_api_strict_flags` will override the values set by environment +variables. Note that the environment variables will only change the defaults +used by array-api-strict initially. They will not change the defaults used by +:func:`reset_array_api_strict_flags`. + +.. envvar:: ARRAY_API_STRICT_API_VERSION + + A string representing the version number. + +.. envvar:: ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES + + "True" or "False" to enable or disable data dependent shapes. + +.. envvar:: ARRAY_API_STRICT_ENABLED_EXTENSIONS + + A comma separated list of extensions to enable. + +Array API Functions +-------------------- + +All functions and methods in +the array API standard are implemented in array-api-strict. See the `Array API +Standard +`__ for +full documentation for each function. diff --git a/docs/changelog.md b/docs/changelog.md index 8f1c203..04c383d 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -20,7 +20,7 @@ This is the first release of `array_api_strict`. It is extracted from `numpy.array_api`, which was included as an experimental submodule in NumPy versions prior to 2.0. Note that the commit history in this repository is -extracted from the git history of numpy/array_api/ (see the [README](README.md)). +extracted from the git history of numpy/array_api/ (see [](numpy.array_api)). Additionally, the following changes are new to `array_api_strict` from `numpy.array_api` in NumPy 1.26 (the last NumPy feature release to include diff --git a/docs/conf.py b/docs/conf.py index c068b06..e4c66d7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,8 +22,8 @@ extensions = [ 'myst_parser', - # 'sphinx.ext.autodoc', - # 'sphinx.ext.napoleon', + 'sphinx.ext.autodoc', + 'sphinx.ext.napoleon', # 'sphinx.ext.intersphinx', 'sphinx_copybutton', ] diff --git a/docs/index.md b/docs/index.md index 307a9c2..6e84efa 100644 --- a/docs/index.md +++ b/docs/index.md @@ -183,6 +183,7 @@ issue, but this hasn't necessarily been tested thoroughly. API standard. [Support for 2023.12 is planned](https://github.com/data-apis/array-api-strict/issues/25). +(numpy.array_api)= ## Relationship to `numpy.array_api` Previously this implementation was available as `numpy.array_api`, but it was @@ -201,5 +202,6 @@ git_filter_repo.py --path numpy/array_api/ --path-rename numpy/array_api:array_a :titlesonly: :hidden: +api.rst changelog.md ```