From 2cd4c3f58f769c73a4c905593d90d678517d41da Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 8 Apr 2024 16:03:25 -0600 Subject: [PATCH 01/19] Set up basic structure for array-api-strict flags Flags are global variables that set array-api-strict in a specific mode. Currently support flags change the support array API standard version, enable or disable data-dependent shapes, and enable or disable optional extensions. This commit only sets up the structure for setting and getting these flags. --- array_api_strict/__init__.py | 11 ++ array_api_strict/_flags.py | 256 +++++++++++++++++++++++++++++++++++ 2 files changed, 267 insertions(+) create mode 100644 array_api_strict/_flags.py diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 90c82c2..1fab8a2 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -284,6 +284,17 @@ __all__ += ["all", "any"] +# Helper functions that are not part of the standard + +from ._flags import ( + set_array_api_strict_flags, + get_array_api_strict_flags, + reset_array_api_strict_flags, + ArrayApiStrictFlags, +) + +__all__ += ['set_array_api_strict_flags', 'get_array_api_strict_flags', 'reset_array_api_strict_flags', 'ArrayApiStrictFlags'] + from . import _version __version__ = _version.get_versions()['version'] del _version diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py new file mode 100644 index 0000000..0d76903 --- /dev/null +++ b/array_api_strict/_flags.py @@ -0,0 +1,256 @@ +""" +This file defines flags for that allow array-api-strict to be used in +different "modes". These modes include + +- Changing to different supported versions of the standard. +- Enabling or disabling different optional behaviors (such as data-dependent + shapes). +- Enabling or disabling different optional extensions. + +Nothing in this file is part of the standard itself. A typical array API +library will only support one particular configuration of these flags. +""" + +import os + +supported_versions = [ + "2021.12", + "2022.12", +] + +STANDARD_VERSION = "2022.12" + +DATA_DEPENDENT_SHAPES = True + +all_extensions = [ + "linalg", + "fft", +] + +extension_versions = { + "linalg": "2021.12", + "fft": "2022.12", +} + +ENABLED_EXTENSIONS = [ + "linalg", + "fft", +] + +def set_array_api_strict_flags( + *, + standard_version=None, + data_dependent_shapes=None, + enabled_extensions=None, +): + """ + Set the array-api-strict flags to the specified values. + + Flags are global variables that enable or disable array-api-strict + behaviors. + + .. note:: + + This function is **not** part of the array API standard. It only exists + in array-api-strict. + + - `standard_version`: The version of the standard to use. Supported + versions are: ``{supported_versions}``. The default version number is + ``{default_version!r}``. + + - `data_dependent_shapes`: Whether data-dependent shapes are enabled in + array-api-strict. This flag is enabled by default. Array libraries that + use computation graphs may not be able to support functions whose output + shapes depend on the input data. + + This flag is enabled by default. Array libraries that use computation graphs may not be able to support + functions whose output shapes depend on the input data. + + The functions that make use of data-dependent shapes, and are therefore + disabled by setting this flag to False are + + - `unique_all`, `unique_counts`, `unique_inverse`, and `unique_values`. + - `nonzero` + - Boolean array indexing + - `repeat` when the `repeats` argument is an array (requires 2023.12 + version of the standard) + + See + https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html + for more details. + + - `enabled_extensions`: A list of extensions that are enabled in + array-api-strict. The default is ``{default_extensions}``. Note that + some extensions require a minimum version of the standard. + + The default values of the flags can also be changed by setting environment + variables: + + - ``ARRAY_API_STRICT_STANDARD_VERSION``: A string representing the version number. + - ``ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES``: "True" or "False". + - ``ARRAY_API_STRICT_ENABLED_EXTENSIONS``: A comma separated list of + extensions to enable. + + Examples + -------- + + >>> from array_api_strict import set_array_api_strict_flags + >>> # Set the standard version to 2021.12 + >>> set_array_api_strict_flags(standard_version="2021.12") + >>> # Disable data-dependent shapes + >>> set_array_api_strict_flags(data_dependent_shapes=False) + >>> # Enable only the linalg extension (disable the fft extension) + >>> set_array_api_strict_flags(enabled_extensions=["linalg"]) + + See Also + -------- + + get_array_api_strict_flags + reset_array_api_strict_flags + ArrayApiStrictFlags: A context manager to temporarily set the flags. + + """ + global STANDARD_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS + + if standard_version is not None: + if standard_version not in supported_versions: + raise ValueError(f"Unsupported standard version {standard_version}") + STANDARD_VERSION = standard_version + + if data_dependent_shapes is not None: + DATA_DEPENDENT_SHAPES = data_dependent_shapes + + if enabled_extensions is not None: + for extension in enabled_extensions: + if extension not in all_extensions: + raise ValueError(f"Unsupported extension {extension}") + if extension_versions[extension] > STANDARD_VERSION: + raise ValueError( + f"Extension {extension} requires standard version " + f"{extension_versions[extension]} or later" + ) + ENABLED_EXTENSIONS = enabled_extensions + +# We have to do this separately or it won't get added as the docstring +set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( + supported_versions=supported_versions, + default_version=STANDARD_VERSION, + default_extensions=ENABLED_EXTENSIONS, +) + +def get_array_api_strict_flags(): + """ + Get the current array-api-strict flags. + + .. note:: + + This function is **not** part of the array API standard. It only exists + in array-api-strict. + + Returns + ------- + dict + A dictionary containing the current array-api-strict flags. + + Examples + -------- + + >>> from array_api_strict import get_array_api_strict_flags + >>> flags = get_array_api_strict_flags() + >>> flags + {'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ['linalg', 'fft']} + + See Also + -------- + + set_array_api_strict_flags + reset_array_api_strict_flags + ArrayApiStrictFlags: A context manager to temporarily set the flags. + + """ + return { + "standard_version": STANDARD_VERSION, + "data_dependent_shapes": DATA_DEPENDENT_SHAPES, + "enabled_extensions": ENABLED_EXTENSIONS, + } + + +def reset_array_api_strict_flags(): + """ + Reset the array-api-strict flags to their default values. + + .. note:: + + This function is **not** part of the array API standard. It only exists + in array-api-strict. + + Examples + -------- + + >>> from array_api_strict import reset_array_api_strict_flags + >>> reset_array_api_strict_flags() + + See Also + -------- + + set_array_api_strict_flags + get_array_api_strict_flags + ArrayApiStrictFlags: A context manager to temporarily set the flags. + + """ + global STANDARD_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS + STANDARD_VERSION = "2022.12" + DATA_DEPENDENT_SHAPES = True + ENABLED_EXTENSIONS = ["linalg", "fft"] + + +class ArrayApiStrictFlags: + """ + A context manager to temporarily set the array-api-strict flags. + + .. note:: + + This class is **not** part of the array API standard. It only exists + in array-api-strict. + + See :func:`~.array_api_strict.set_array_api_strict_flags` for a + description of the available flags. + + See Also + -------- + + set_array_api_strict_flags + get_array_api_strict_flags + reset_array_api_strict_flags + + """ + def __init__(self, *, standard_version=None, data_dependent_shapes=None, + enabled_extensions=None): + self.kwargs = { + "standard_version": standard_version, + "data_dependent_shapes": data_dependent_shapes, + "enabled_extensions": enabled_extensions, + } + self.old_flags = get_array_api_strict_flags() + + def __enter__(self): + set_array_api_strict_flags(**self.kwargs) + + def __exit__(self, exc_type, exc_value, traceback): + set_array_api_strict_flags(**self.old_flags) + +# Set the flags from the environment variables +if "ARRAY_API_STRICT_STANDARD_VERSION" in os.environ: + set_array_api_strict_flags( + standard_version=os.environ["ARRAY_API_STRICT_STANDARD_VERSION"] + ) + +if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ: + set_array_api_strict_flags( + data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true" + ) + +if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ: + set_array_api_strict_flags( + enabled_extensions=os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",") + ) From d8c3745372e2d87b0be8bc87c5709246fe8d455a Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 9 Apr 2024 17:27:57 -0600 Subject: [PATCH 02/19] Disable extensions when setting the standard version --- array_api_strict/_flags.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 0d76903..a26258e 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -129,7 +129,9 @@ def set_array_api_strict_flags( f"Extension {extension} requires standard version " f"{extension_versions[extension]} or later" ) - ENABLED_EXTENSIONS = enabled_extensions + ENABLED_EXTENSIONS = tuple(enabled_extensions) + else: + ENABLED_EXTENSIONS = tuple([ext for ext in all_extensions if extension_versions[ext] <= STANDARD_VERSION]) # We have to do this separately or it won't get added as the docstring set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( From f34576c63b2d2c7f3c07cf03b2d48511010118fe Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 9 Apr 2024 17:29:08 -0600 Subject: [PATCH 03/19] Some small code cleanups to the flags file --- array_api_strict/_flags.py | 63 ++++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index a26258e..1d50ba3 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -11,31 +11,34 @@ library will only support one particular configuration of these flags. """ +import functools import os -supported_versions = [ +supported_versions = ( "2021.12", "2022.12", -] +) -STANDARD_VERSION = "2022.12" +STANDARD_VERSION = default_version = "2022.12" DATA_DEPENDENT_SHAPES = True -all_extensions = [ +all_extensions = ( "linalg", "fft", -] +) extension_versions = { "linalg": "2021.12", "fft": "2022.12", } -ENABLED_EXTENSIONS = [ +ENABLED_EXTENSIONS = default_extensions = ( "linalg", "fft", -] +) + +# Public functions def set_array_api_strict_flags( *, @@ -136,8 +139,8 @@ def set_array_api_strict_flags( # We have to do this separately or it won't get added as the docstring set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( supported_versions=supported_versions, - default_version=STANDARD_VERSION, - default_extensions=ENABLED_EXTENSIONS, + default_version=default_version, + default_extensions=default_extensions, ) def get_array_api_strict_flags(): @@ -160,7 +163,7 @@ def get_array_api_strict_flags(): >>> from array_api_strict import get_array_api_strict_flags >>> flags = get_array_api_strict_flags() >>> flags - {'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ['linalg', 'fft']} + {'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')} See Also -------- @@ -181,6 +184,8 @@ def reset_array_api_strict_flags(): """ Reset the array-api-strict flags to their default values. + This will also reset any flags that were set by environment variables. + .. note:: This function is **not** part of the array API standard. It only exists @@ -201,9 +206,9 @@ def reset_array_api_strict_flags(): """ global STANDARD_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS - STANDARD_VERSION = "2022.12" + STANDARD_VERSION = default_version DATA_DEPENDENT_SHAPES = True - ENABLED_EXTENSIONS = ["linalg", "fft"] + ENABLED_EXTENSIONS = default_extensions class ArrayApiStrictFlags: @@ -241,18 +246,22 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): set_array_api_strict_flags(**self.old_flags) -# Set the flags from the environment variables -if "ARRAY_API_STRICT_STANDARD_VERSION" in os.environ: - set_array_api_strict_flags( - standard_version=os.environ["ARRAY_API_STRICT_STANDARD_VERSION"] - ) - -if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ: - set_array_api_strict_flags( - data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true" - ) - -if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ: - set_array_api_strict_flags( - enabled_extensions=os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",") - ) +# Private functions + +def set_flags_from_environment(): + if "ARRAY_API_STRICT_STANDARD_VERSION" in os.environ: + set_array_api_strict_flags( + standard_version=os.environ["ARRAY_API_STRICT_STANDARD_VERSION"] + ) + + if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ: + set_array_api_strict_flags( + data_dependent_shapes=os.environ["ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES"].lower() == "true" + ) + + if "ARRAY_API_STRICT_ENABLED_EXTENSIONS" in os.environ: + set_array_api_strict_flags( + enabled_extensions=os.environ["ARRAY_API_STRICT_ENABLED_EXTENSIONS"].split(",") + ) + +set_flags_from_environment() From 6a20e916827ae4c7c46a00e82c9dd8621c3285b8 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 9 Apr 2024 17:29:27 -0600 Subject: [PATCH 04/19] Add functionality for the data_dependent_shapes flag --- array_api_strict/_array_object.py | 19 ++++++++++++------- array_api_strict/_flags.py | 8 ++++++++ array_api_strict/_searching_functions.py | 2 ++ array_api_strict/_set_functions.py | 6 ++++++ 4 files changed, 28 insertions(+), 7 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 39808f0..e58767f 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -32,6 +32,7 @@ _result_type, _dtype_categories, ) +from ._flags import get_array_api_strict_flags from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex import types @@ -427,13 +428,17 @@ def _validate_index(self, key): "the Array API)" ) elif isinstance(i, Array): - if i.dtype in _boolean_dtypes and len(_key) != 1: - assert isinstance(key, tuple) # sanity check - raise IndexError( - f"Single-axes index {i} is a boolean array and " - f"{len(key)=}, but masking is only specified in the " - "Array API when the array is the sole index." - ) + if i.dtype in _boolean_dtypes: + if len(_key) != 1: + assert isinstance(key, tuple) # sanity check + raise IndexError( + f"Single-axes index {i} is a boolean array and " + f"{len(key)=}, but masking is only specified in the " + "Array API when the array is the sole index." + ) + if not get_array_api_strict_flags()['data_dependent_shapes']: + raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") + elif i.dtype in _integer_dtypes and i.ndim != 0: raise IndexError( f"Single-axes index {i} is a non-zero-dimensional " diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 1d50ba3..33599ea 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -265,3 +265,11 @@ def set_flags_from_environment(): ) set_flags_from_environment() + +def requires_data_dependent_shapes(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not DATA_DEPENDENT_SHAPES: + raise RuntimeError(f"The function {func.__name__} requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") + return func(*args, **kwargs) + return wrapper diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index f4b2f56..9781531 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -2,6 +2,7 @@ from ._array_object import Array from ._dtypes import _result_type, _real_numeric_dtypes +from ._flags import requires_data_dependent_shapes from typing import Optional, Tuple @@ -30,6 +31,7 @@ def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) - return Array._new(np.asarray(np.argmin(x._array, axis=axis, keepdims=keepdims))) +@requires_data_dependent_shapes def nonzero(x: Array, /) -> Tuple[Array, ...]: """ Array API compatible wrapper for :py:func:`np.nonzero `. diff --git a/array_api_strict/_set_functions.py b/array_api_strict/_set_functions.py index 0b4132c..e6ca939 100644 --- a/array_api_strict/_set_functions.py +++ b/array_api_strict/_set_functions.py @@ -2,6 +2,8 @@ from ._array_object import Array +from ._flags import requires_data_dependent_shapes + from typing import NamedTuple import numpy as np @@ -35,6 +37,7 @@ class UniqueInverseResult(NamedTuple): inverse_indices: Array +@requires_data_dependent_shapes def unique_all(x: Array, /) -> UniqueAllResult: """ Array API compatible wrapper for :py:func:`np.unique `. @@ -59,6 +62,7 @@ def unique_all(x: Array, /) -> UniqueAllResult: ) +@requires_data_dependent_shapes def unique_counts(x: Array, /) -> UniqueCountsResult: res = np.unique( x._array, @@ -71,6 +75,7 @@ def unique_counts(x: Array, /) -> UniqueCountsResult: return UniqueCountsResult(*[Array._new(i) for i in res]) +@requires_data_dependent_shapes def unique_inverse(x: Array, /) -> UniqueInverseResult: """ Array API compatible wrapper for :py:func:`np.unique `. @@ -90,6 +95,7 @@ def unique_inverse(x: Array, /) -> UniqueInverseResult: return UniqueInverseResult(Array._new(values), Array._new(inverse_indices)) +@requires_data_dependent_shapes def unique_values(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.unique `. From 4705b9fb25b075bb2c7695acc6dba88fc6b79fa9 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 9 Apr 2024 17:29:38 -0600 Subject: [PATCH 05/19] Add tests for flags --- array_api_strict/tests/test_flags.py | 78 ++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 array_api_strict/tests/test_flags.py diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py new file mode 100644 index 0000000..ede4b96 --- /dev/null +++ b/array_api_strict/tests/test_flags.py @@ -0,0 +1,78 @@ +from .._flags import (set_array_api_strict_flags, get_array_api_strict_flags, + reset_array_api_strict_flags) + +from .. import (asarray, unique_all, unique_counts, unique_inverse, + unique_values, nonzero) + +import pytest + +@pytest.fixture(autouse=True) +def reset_flags(): + reset_array_api_strict_flags() + yield + reset_array_api_strict_flags() + +def test_flags(): + # Test defaults + flags = get_array_api_strict_flags() + assert flags == { + 'standard_version': '2022.12', + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } + + # Test setting flags + set_array_api_strict_flags(data_dependent_shapes=False) + flags = get_array_api_strict_flags() + assert flags == { + 'standard_version': '2022.12', + 'data_dependent_shapes': False, + 'enabled_extensions': ('linalg', 'fft'), + } + set_array_api_strict_flags(enabled_extensions=('fft',)) + flags = get_array_api_strict_flags() + assert flags == { + 'standard_version': '2022.12', + 'data_dependent_shapes': False, + 'enabled_extensions': ('fft',), + } + # Make sure setting the version to 2021.12 disables fft + set_array_api_strict_flags(standard_version='2021.12') + flags = get_array_api_strict_flags() + assert flags == { + 'standard_version': '2021.12', + 'data_dependent_shapes': False, + 'enabled_extensions': ('linalg',), + } + + # Test setting flags with invalid values + pytest.raises(ValueError, lambda: + set_array_api_strict_flags(standard_version='2020.12')) + pytest.raises(ValueError, lambda: set_array_api_strict_flags( + enabled_extensions=('linalg', 'fft', 'invalid'))) + pytest.raises(ValueError, lambda: set_array_api_strict_flags( + standard_version='2021.12', + enabled_extensions=('linalg', 'fft'))) + + +def test_data_dependent_shapes(): + a = asarray([0, 0, 1, 2, 2]) + mask = asarray([True, False, True, False, True]) + + # Should not error + unique_all(a) + unique_counts(a) + unique_inverse(a) + unique_values(a) + nonzero(a) + a[mask] + # TODO: add repeat when it is implemented + + set_array_api_strict_flags(data_dependent_shapes=False) + + pytest.raises(RuntimeError, lambda: unique_all(a)) + pytest.raises(RuntimeError, lambda: unique_counts(a)) + pytest.raises(RuntimeError, lambda: unique_inverse(a)) + pytest.raises(RuntimeError, lambda: unique_values(a)) + pytest.raises(RuntimeError, lambda: nonzero(a)) + pytest.raises(RuntimeError, lambda: a[mask]) From 689a776aa6a581c321b6793985c9c135880d149d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 12 Apr 2024 16:44:23 -0600 Subject: [PATCH 06/19] Respect the extension flag in linalg and fft This behavior still needs to be tested. This required moving the linalg functions that are also in the main namespace so that they can still work there even when the linalg extension is disabled. The way I've decided to implement this is that the functions will not raise an exception until they are called. It would probably be more convenient for users if they raised an attribute error, or if the extension namespace itself did, like it would in a real library without the given extension. But the implementation for this would be a lot more complicated and didn't really feel worth it to me. --- array_api_strict/__init__.py | 2 +- array_api_strict/_flags.py | 14 +++ array_api_strict/_linear_algebra_functions.py | 68 ++++++++++++ array_api_strict/fft.py | 15 +++ array_api_strict/linalg.py | 104 +++++++++--------- 5 files changed, 149 insertions(+), 54 deletions(-) create mode 100644 array_api_strict/_linear_algebra_functions.py diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 1fab8a2..31f0992 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -244,7 +244,7 @@ __all__ += ["linalg"] -from .linalg import matmul, tensordot, matrix_transpose, vecdot +from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot __all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"] diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 33599ea..bbe2c59 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -273,3 +273,17 @@ def wrapper(*args, **kwargs): raise RuntimeError(f"The function {func.__name__} requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict") return func(*args, **kwargs) return wrapper + +def requires_extension(extension): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if extension not in ENABLED_EXTENSIONS: + if extension == 'linalg' \ + and func.__name__ in ['matmul', 'tensordot', + 'matrix_transpose', 'vecdot']: + raise RuntimeError(f"The linalg extension has been disabled for array-api-strict. However, {func.__name__} is also present in the main array_api_strict namespace and may be used from there.") + raise RuntimeError(f"The function {func.__name__} requires the {extension} extension, but it has been disabled for array-api-strict") + return func(*args, **kwargs) + return wrapper + return decorator diff --git a/array_api_strict/_linear_algebra_functions.py b/array_api_strict/_linear_algebra_functions.py new file mode 100644 index 0000000..1ff08d4 --- /dev/null +++ b/array_api_strict/_linear_algebra_functions.py @@ -0,0 +1,68 @@ +""" +These functions are all also defined in the linalg extension, but we include +them here with wrappers in linalg so that the wrappers can be disabled if the +linalg extension is disabled in the flags. + +""" + +from __future__ import annotations + +from ._dtypes import _numeric_dtypes + +from ._array_object import Array + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ._typing import Sequence, Tuple, Union + +import numpy.linalg +import numpy as np + +# Note: matmul is the numpy top-level namespace but not in np.linalg +def matmul(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.matmul `. + + See its docstring for more information. + """ + # Note: the restriction to numeric dtypes only is different from + # np.matmul. + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in matmul') + + return Array._new(np.matmul(x1._array, x2._array)) + +# Note: tensordot is the numpy top-level namespace but not in np.linalg + +# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. +def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: + # Note: the restriction to numeric dtypes only is different from + # np.tensordot. + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in tensordot') + + return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) + +# Note: this function is new in the array API spec. Unlike transpose, it only +# transposes the last two axes. +def matrix_transpose(x: Array, /) -> Array: + if x.ndim < 2: + raise ValueError("x must be at least 2-dimensional for matrix_transpose") + return Array._new(np.swapaxes(x._array, -1, -2)) + +# Note: vecdot is not in NumPy +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: + if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: + raise TypeError('Only numeric dtypes are allowed in vecdot') + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same size along the given axis") + + x1_, x2_ = np.broadcast_arrays(x1._array, x2._array) + x1_ = np.moveaxis(x1_, axis, -1) + x2_ = np.moveaxis(x2_, axis, -1) + + res = x1_[..., None, :] @ x2_[..., None] + return Array._new(res[..., 0, 0]) diff --git a/array_api_strict/fft.py b/array_api_strict/fft.py index b50e9e3..7f427e5 100644 --- a/array_api_strict/fft.py +++ b/array_api_strict/fft.py @@ -15,9 +15,11 @@ ) from ._array_object import Array, CPU_DEVICE from ._data_type_functions import astype +from ._flags import requires_extension import numpy as np +@requires_extension('fft') def fft( x: Array, /, @@ -40,6 +42,7 @@ def fft( return astype(res, complex64) return res +@requires_extension('fft') def ifft( x: Array, /, @@ -62,6 +65,7 @@ def ifft( return astype(res, complex64) return res +@requires_extension('fft') def fftn( x: Array, /, @@ -84,6 +88,7 @@ def fftn( return astype(res, complex64) return res +@requires_extension('fft') def ifftn( x: Array, /, @@ -106,6 +111,7 @@ def ifftn( return astype(res, complex64) return res +@requires_extension('fft') def rfft( x: Array, /, @@ -128,6 +134,7 @@ def rfft( return astype(res, complex64) return res +@requires_extension('fft') def irfft( x: Array, /, @@ -150,6 +157,7 @@ def irfft( return astype(res, float32) return res +@requires_extension('fft') def rfftn( x: Array, /, @@ -172,6 +180,7 @@ def rfftn( return astype(res, complex64) return res +@requires_extension('fft') def irfftn( x: Array, /, @@ -194,6 +203,7 @@ def irfftn( return astype(res, float32) return res +@requires_extension('fft') def hfft( x: Array, /, @@ -216,6 +226,7 @@ def hfft( return astype(res, float32) return res +@requires_extension('fft') def ihfft( x: Array, /, @@ -238,6 +249,7 @@ def ihfft( return astype(res, complex64) return res +@requires_extension('fft') def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.fftfreq `. @@ -248,6 +260,7 @@ def fftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Ar raise ValueError(f"Unsupported device {device!r}") return Array._new(np.fft.fftfreq(n, d=d)) +@requires_extension('fft') def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.rfftfreq `. @@ -258,6 +271,7 @@ def rfftfreq(n: int, /, *, d: float = 1.0, device: Optional[Device] = None) -> A raise ValueError(f"Unsupported device {device!r}") return Array._new(np.fft.rfftfreq(n, d=d)) +@requires_extension('fft') def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.fftshift `. @@ -268,6 +282,7 @@ def fftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: raise TypeError("Only floating-point dtypes are allowed in fftshift") return Array._new(np.fft.fftshift(x._array, axes=axes)) +@requires_extension('fft') def ifftshift(x: Array, /, *, axes: Union[int, Sequence[int]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.fft.ifftshift `. diff --git a/array_api_strict/linalg.py b/array_api_strict/linalg.py index 78e9ec4..e1998fa 100644 --- a/array_api_strict/linalg.py +++ b/array_api_strict/linalg.py @@ -12,6 +12,7 @@ from ._manipulation_functions import reshape from ._elementwise_functions import conj from ._array_object import Array +from ._flags import requires_extension try: from numpy._core.numeric import normalize_axis_tuple @@ -46,6 +47,7 @@ class SVDResult(NamedTuple): # Note: the inclusion of the upper keyword is different from # np.linalg.cholesky, which does not have it. +@requires_extension('linalg') def cholesky(x: Array, /, *, upper: bool = False) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.cholesky `. @@ -65,6 +67,7 @@ def cholesky(x: Array, /, *, upper: bool = False) -> Array: return Array._new(L) # Note: cross is the numpy top-level namespace, not np.linalg +@requires_extension('linalg') def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: """ Array API compatible wrapper for :py:func:`np.cross `. @@ -80,6 +83,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: raise ValueError('cross() dimension must equal 3') return Array._new(np.cross(x1._array, x2._array, axis=axis)) +@requires_extension('linalg') def det(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.det `. @@ -93,6 +97,7 @@ def det(x: Array, /) -> Array: return Array._new(np.linalg.det(x._array)) # Note: diagonal is the numpy top-level namespace, not np.linalg +@requires_extension('linalg') def diagonal(x: Array, /, *, offset: int = 0) -> Array: """ Array API compatible wrapper for :py:func:`np.diagonal `. @@ -103,7 +108,7 @@ def diagonal(x: Array, /, *, offset: int = 0) -> Array: # operates on the first two axes by default return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1)) - +@requires_extension('linalg') def eigh(x: Array, /) -> EighResult: """ Array API compatible wrapper for :py:func:`np.linalg.eigh `. @@ -120,6 +125,7 @@ def eigh(x: Array, /) -> EighResult: return EighResult(*map(Array._new, np.linalg.eigh(x._array))) +@requires_extension('linalg') def eigvalsh(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.eigvalsh `. @@ -133,6 +139,7 @@ def eigvalsh(x: Array, /) -> Array: return Array._new(np.linalg.eigvalsh(x._array)) +@requires_extension('linalg') def inv(x: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.inv `. @@ -146,28 +153,13 @@ def inv(x: Array, /) -> Array: return Array._new(np.linalg.inv(x._array)) - -# Note: matmul is the numpy top-level namespace but not in np.linalg -def matmul(x1: Array, x2: Array, /) -> Array: - """ - Array API compatible wrapper for :py:func:`np.matmul `. - - See its docstring for more information. - """ - # Note: the restriction to numeric dtypes only is different from - # np.matmul. - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in matmul') - - return Array._new(np.matmul(x1._array, x2._array)) - - # Note: the name here is different from norm(). The array API norm is split # into matrix_norm and vector_norm(). # The type for ord should be Optional[Union[int, float, Literal[np.inf, # -np.inf, 'fro', 'nuc']]], but Literal does not support floating-point # literals. +@requires_extension('linalg') def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.norm `. @@ -182,6 +174,7 @@ def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord)) +@requires_extension('linalg') def matrix_power(x: Array, n: int, /) -> Array: """ Array API compatible wrapper for :py:func:`np.matrix_power `. @@ -197,6 +190,7 @@ def matrix_power(x: Array, n: int, /) -> Array: return Array._new(np.linalg.matrix_power(x._array, n)) # Note: the keyword argument name rtol is different from np.linalg.matrix_rank +@requires_extension('linalg') def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.matrix_rank `. @@ -219,14 +213,8 @@ def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> A return Array._new(np.count_nonzero(S > tol, axis=-1)) -# Note: this function is new in the array API spec. Unlike transpose, it only -# transposes the last two axes. -def matrix_transpose(x: Array, /) -> Array: - if x.ndim < 2: - raise ValueError("x must be at least 2-dimensional for matrix_transpose") - return Array._new(np.swapaxes(x._array, -1, -2)) - # Note: outer is the numpy top-level namespace, not np.linalg +@requires_extension('linalg') def outer(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.outer `. @@ -245,6 +233,7 @@ def outer(x1: Array, x2: Array, /) -> Array: return Array._new(np.outer(x1._array, x2._array)) # Note: the keyword argument name rtol is different from np.linalg.pinv +@requires_extension('linalg') def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.pinv `. @@ -262,6 +251,7 @@ def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array: rtol = max(x.shape[-2:]) * finfo(x.dtype).eps return Array._new(np.linalg.pinv(x._array, rcond=rtol)) +@requires_extension('linalg') def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult: """ Array API compatible wrapper for :py:func:`np.linalg.qr `. @@ -277,6 +267,7 @@ def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRRe # np.linalg.qr, which only returns a tuple. return QRResult(*map(Array._new, np.linalg.qr(x._array, mode=mode))) +@requires_extension('linalg') def slogdet(x: Array, /) -> SlogdetResult: """ Array API compatible wrapper for :py:func:`np.linalg.slogdet `. @@ -335,6 +326,7 @@ def _solve(a, b): return wrap(r.astype(result_t, copy=False)) +@requires_extension('linalg') def solve(x1: Array, x2: Array, /) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.solve `. @@ -348,6 +340,7 @@ def solve(x1: Array, x2: Array, /) -> Array: return Array._new(_solve(x1._array, x2._array)) +@requires_extension('linalg') def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: """ Array API compatible wrapper for :py:func:`np.linalg.svd `. @@ -365,23 +358,14 @@ def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult: # Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to # np.linalg.svd(compute_uv=False). +@requires_extension('linalg') def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]: if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in svdvals') return Array._new(np.linalg.svd(x._array, compute_uv=False)) -# Note: tensordot is the numpy top-level namespace but not in np.linalg - -# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like. -def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: - # Note: the restriction to numeric dtypes only is different from - # np.tensordot. - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in tensordot') - - return Array._new(np.tensordot(x1._array, x2._array, axes=axes)) - # Note: trace is the numpy top-level namespace, not np.linalg +@requires_extension('linalg') def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Array: """ Array API compatible wrapper for :py:func:`np.trace `. @@ -404,29 +388,12 @@ def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Arr # operates on the first two axes by default return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype))) -# Note: vecdot is not in NumPy -def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: - if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: - raise TypeError('Only numeric dtypes are allowed in vecdot') - ndim = max(x1.ndim, x2.ndim) - x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) - x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) - if x1_shape[axis] != x2_shape[axis]: - raise ValueError("x1 and x2 must have the same size along the given axis") - - x1_, x2_ = np.broadcast_arrays(x1._array, x2._array) - x1_ = np.moveaxis(x1_, axis, -1) - x2_ = np.moveaxis(x2_, axis, -1) - - res = x1_[..., None, :] @ x2_[..., None] - return Array._new(res[..., 0, 0]) - - # Note: the name here is different from norm(). The array API norm is split # into matrix_norm and vector_norm(). # The type for ord should be Optional[Union[int, float, Literal[np.inf, # -np.inf]]] but Literal does not support floating-point literals. +@requires_extension('linalg') def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array: """ Array API compatible wrapper for :py:func:`np.linalg.norm `. @@ -472,4 +439,35 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No return res +# These functions are also in the main namespace. We define them here as +# wrappers so that they can still be disabled when the linalg extension is +# disabled without disabling the versions in the main namespace. + +# Note: matmul is the numpy top-level namespace but not in np.linalg +@requires_extension('linalg') +def matmul(x1: Array, x2: Array, /) -> Array: + """ + Array API compatible wrapper for :py:func:`np.matmul `. + + See its docstring for more information. + """ + from ._linear_algebra_functions import matmul + return matmul(x1, x2) + +# Note: tensordot is the numpy top-level namespace but not in np.linalg +@requires_extension('linalg') +def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array: + from ._linear_algebra_functions import tensordot + return tensordot(x1, x2, axes=axes) + +@requires_extension('linalg') +def matrix_transpose(x: Array, /) -> Array: + from ._linear_algebra_functions import matrix_transpose + return matrix_transpose(x) + +@requires_extension('linalg') +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: + from ._linear_algebra_functions import vecdot + return vecdot(x1, x2, axis=axis) + __all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm'] From 78d368d6ba3f3adb2f138b0a0d8cd8294db61f0a Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 12 Apr 2024 17:06:17 -0600 Subject: [PATCH 07/19] Add tests for disabling linalg and fft extensions --- array_api_strict/tests/test_flags.py | 91 ++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index ede4b96..d3d957f 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -4,6 +4,8 @@ from .. import (asarray, unique_all, unique_counts, unique_inverse, unique_values, nonzero) +import array_api_strict as xp + import pytest @pytest.fixture(autouse=True) @@ -76,3 +78,92 @@ def test_data_dependent_shapes(): pytest.raises(RuntimeError, lambda: unique_values(a)) pytest.raises(RuntimeError, lambda: nonzero(a)) pytest.raises(RuntimeError, lambda: a[mask]) + +linalg_examples = { + 'cholesky': lambda: xp.linalg.cholesky(xp.eye(3)), + 'cross': lambda: xp.linalg.cross(xp.asarray([1, 0, 0]), xp.asarray([0, 1, 0])), + 'det': lambda: xp.linalg.det(xp.eye(3)), + 'diagonal': lambda: xp.linalg.diagonal(xp.eye(3)), + 'eigh': lambda: xp.linalg.eigh(xp.eye(3)), + 'eigvalsh': lambda: xp.linalg.eigvalsh(xp.eye(3)), + 'inv': lambda: xp.linalg.inv(xp.eye(3)), + 'matmul': lambda: xp.linalg.matmul(xp.eye(3), xp.eye(3)), + 'matrix_norm': lambda: xp.linalg.matrix_norm(xp.eye(3)), + 'matrix_power': lambda: xp.linalg.matrix_power(xp.eye(3), 2), + 'matrix_rank': lambda: xp.linalg.matrix_rank(xp.eye(3)), + 'matrix_transpose': lambda: xp.linalg.matrix_transpose(xp.eye(3)), + 'outer': lambda: xp.linalg.outer(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), + 'pinv': lambda: xp.linalg.pinv(xp.eye(3)), + 'qr': lambda: xp.linalg.qr(xp.eye(3)), + 'slogdet': lambda: xp.linalg.slogdet(xp.eye(3)), + 'solve': lambda: xp.linalg.solve(xp.eye(3), xp.eye(3)), + 'svd': lambda: xp.linalg.svd(xp.eye(3)), + 'svdvals': lambda: xp.linalg.svdvals(xp.eye(3)), + 'tensordot': lambda: xp.linalg.tensordot(xp.eye(3), xp.eye(3)), + 'trace': lambda: xp.linalg.trace(xp.eye(3)), + 'vecdot': lambda: xp.linalg.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), + 'vector_norm': lambda: xp.linalg.vector_norm(xp.asarray([1., 2., 3.])), +} + +assert set(linalg_examples) == set(xp.linalg.__all__) + +linalg_main_namespace_examples = { + 'matmul': lambda: xp.matmul(xp.eye(3), xp.eye(3)), + 'matrix_transpose': lambda: xp.matrix_transpose(xp.eye(3)), + 'tensordot': lambda: xp.tensordot(xp.eye(3), xp.eye(3)), + 'vecdot': lambda: xp.vecdot(xp.asarray([1, 2, 3]), xp.asarray([4, 5, 6])), +} + +assert set(linalg_main_namespace_examples) == set(xp.__all__) & set(xp.linalg.__all__) + +@pytest.mark.parametrize('func_name', linalg_examples.keys()) +def test_linalg(func_name): + func = linalg_examples[func_name] + if func_name in linalg_main_namespace_examples: + main_namespace_func = linalg_main_namespace_examples[func_name] + else: + main_namespace_func = lambda: None + + # First make sure the example actually works + func() + main_namespace_func() + + set_array_api_strict_flags(enabled_extensions=()) + pytest.raises(RuntimeError, func) + main_namespace_func() + + set_array_api_strict_flags(enabled_extensions=('linalg',)) + func() + main_namespace_func() + +fft_examples = { + 'fft': lambda: xp.fft.fft(xp.asarray([0j, 1j, 0j, 0j])), + 'ifft': lambda: xp.fft.ifft(xp.asarray([0j, 1j, 0j, 0j])), + 'fftn': lambda: xp.fft.fftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'ifftn': lambda: xp.fft.ifftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'rfft': lambda: xp.fft.rfft(xp.asarray([0., 1., 0., 0.])), + 'irfft': lambda: xp.fft.irfft(xp.asarray([0j, 1j, 0j, 0j])), + 'rfftn': lambda: xp.fft.rfftn(xp.asarray([[0., 1.], [0., 0.]])), + 'irfftn': lambda: xp.fft.irfftn(xp.asarray([[0j, 1j], [0j, 0j]])), + 'hfft': lambda: xp.fft.hfft(xp.asarray([0j, 1j, 0j, 0j])), + 'ihfft': lambda: xp.fft.ihfft(xp.asarray([0., 1., 0., 0.])), + 'fftfreq': lambda: xp.fft.fftfreq(4), + 'rfftfreq': lambda: xp.fft.rfftfreq(4), + 'fftshift': lambda: xp.fft.fftshift(xp.asarray([0j, 1j, 0j, 0j])), + 'ifftshift': lambda: xp.fft.ifftshift(xp.asarray([0j, 1j, 0j, 0j])), +} + +assert set(fft_examples) == set(xp.fft.__all__) + +@pytest.mark.parametrize('func_name', fft_examples.keys()) +def test_fft(func_name): + func = fft_examples[func_name] + + # First make sure the example actually works + func() + + set_array_api_strict_flags(enabled_extensions=()) + pytest.raises(RuntimeError, func) + + set_array_api_strict_flags(enabled_extensions=('fft',)) + func() From c32a452864d10afbeebac31cf9d30858289a70fa Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Fri, 12 Apr 2024 23:37:06 -0600 Subject: [PATCH 08/19] Add docstring to __array_namespace__ --- array_api_strict/_array_object.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index e58767f..9150381 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -487,6 +487,20 @@ def __and__(self: Array, other: Union[int, bool, Array], /) -> Array: def __array_namespace__( self: Array, /, *, api_version: Optional[str] = None ) -> types.ModuleType: + """ + Return the array_api_strict namespace corresponding to api_version. + + The default API version is '2022.12'. Note that '2021.12' is supported, + but currently identical to '2022.12'. + + For array_api_strict, calling this function with api_version will set + the API version for the array_api_strict module globally. This can + also be achieved with the + {func}`array_api_strict.set_array_api_strict_flags` function. If you + want some way to only set the version locally, use the + {class}`array_api_strict.ArrayApiStrictFlags` context manager. + + """ if api_version is not None and api_version not in ["2021.12", "2022.12"]: raise ValueError(f"Unrecognized array API version: {api_version!r}") if api_version == "2021.12": From 632900f92082f7701cb2cdd79b0f99dcf336649c Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 16 Apr 2024 12:52:57 -0600 Subject: [PATCH 09/19] Remove duplicate sentence --- array_api_strict/_flags.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index bbe2c59..13ba89f 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -62,12 +62,11 @@ def set_array_api_strict_flags( ``{default_version!r}``. - `data_dependent_shapes`: Whether data-dependent shapes are enabled in - array-api-strict. This flag is enabled by default. Array libraries that - use computation graphs may not be able to support functions whose output - shapes depend on the input data. + array-api-strict. - This flag is enabled by default. Array libraries that use computation graphs may not be able to support - functions whose output shapes depend on the input data. + This flag is enabled by default. Array libraries that use computation + graphs may not be able to support functions whose output shapes depend + on the input data. The functions that make use of data-dependent shapes, and are therefore disabled by setting this flag to False are From befd28c198ad010e2036221429c44ef3265085e8 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 16 Apr 2024 12:55:52 -0600 Subject: [PATCH 10/19] Set the api version flag in __array_namespace__ --- array_api_strict/_array_object.py | 5 ++--- array_api_strict/_flags.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 9150381..a2d68a5 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -32,7 +32,7 @@ _result_type, _dtype_categories, ) -from ._flags import get_array_api_strict_flags +from ._flags import get_array_api_strict_flags, set_array_api_strict_flags from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex import types @@ -501,8 +501,7 @@ def __array_namespace__( {class}`array_api_strict.ArrayApiStrictFlags` context manager. """ - if api_version is not None and api_version not in ["2021.12", "2022.12"]: - raise ValueError(f"Unrecognized array API version: {api_version!r}") + set_array_api_strict_flags(standard_version=api_version) if api_version == "2021.12": warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") import array_api_strict diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 13ba89f..eb76289 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -116,7 +116,7 @@ def set_array_api_strict_flags( if standard_version is not None: if standard_version not in supported_versions: - raise ValueError(f"Unsupported standard version {standard_version}") + raise ValueError(f"Unsupported standard version {standard_version!r}") STANDARD_VERSION = standard_version if data_dependent_shapes is not None: From 319799ee31e0c3778f27c79ca299b0bddab11317 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 16 Apr 2024 12:57:32 -0600 Subject: [PATCH 11/19] Rename the "standard_version" flag to "api_version" This matches the name used in __array_namespace__ --- array_api_strict/_array_object.py | 2 +- array_api_strict/_flags.py | 40 ++++++++++++++-------------- array_api_strict/tests/test_flags.py | 14 +++++----- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index a2d68a5..28d0eb9 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -501,7 +501,7 @@ def __array_namespace__( {class}`array_api_strict.ArrayApiStrictFlags` context manager. """ - set_array_api_strict_flags(standard_version=api_version) + set_array_api_strict_flags(api_version=api_version) if api_version == "2021.12": warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") import array_api_strict diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index eb76289..d02476d 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -19,7 +19,7 @@ "2022.12", ) -STANDARD_VERSION = default_version = "2022.12" +API_VERSION = default_version = "2022.12" DATA_DEPENDENT_SHAPES = True @@ -42,7 +42,7 @@ def set_array_api_strict_flags( *, - standard_version=None, + api_version=None, data_dependent_shapes=None, enabled_extensions=None, ): @@ -57,7 +57,7 @@ def set_array_api_strict_flags( This function is **not** part of the array API standard. It only exists in array-api-strict. - - `standard_version`: The version of the standard to use. Supported + - `api_version`: The version of the standard to use. Supported versions are: ``{supported_versions}``. The default version number is ``{default_version!r}``. @@ -88,7 +88,7 @@ def set_array_api_strict_flags( The default values of the flags can also be changed by setting environment variables: - - ``ARRAY_API_STRICT_STANDARD_VERSION``: A string representing the version number. + - ``ARRAY_API_STRICT_API_VERSION``: A string representing the version number. - ``ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES``: "True" or "False". - ``ARRAY_API_STRICT_ENABLED_EXTENSIONS``: A comma separated list of extensions to enable. @@ -98,7 +98,7 @@ def set_array_api_strict_flags( >>> from array_api_strict import set_array_api_strict_flags >>> # Set the standard version to 2021.12 - >>> set_array_api_strict_flags(standard_version="2021.12") + >>> set_array_api_strict_flags(api_version="2021.12") >>> # Disable data-dependent shapes >>> set_array_api_strict_flags(data_dependent_shapes=False) >>> # Enable only the linalg extension (disable the fft extension) @@ -112,12 +112,12 @@ def set_array_api_strict_flags( ArrayApiStrictFlags: A context manager to temporarily set the flags. """ - global STANDARD_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS + global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS - if standard_version is not None: - if standard_version not in supported_versions: - raise ValueError(f"Unsupported standard version {standard_version!r}") - STANDARD_VERSION = standard_version + if api_version is not None: + if api_version not in supported_versions: + raise ValueError(f"Unsupported standard version {api_version!r}") + API_VERSION = api_version if data_dependent_shapes is not None: DATA_DEPENDENT_SHAPES = data_dependent_shapes @@ -126,14 +126,14 @@ def set_array_api_strict_flags( for extension in enabled_extensions: if extension not in all_extensions: raise ValueError(f"Unsupported extension {extension}") - if extension_versions[extension] > STANDARD_VERSION: + if extension_versions[extension] > API_VERSION: raise ValueError( f"Extension {extension} requires standard version " f"{extension_versions[extension]} or later" ) ENABLED_EXTENSIONS = tuple(enabled_extensions) else: - ENABLED_EXTENSIONS = tuple([ext for ext in all_extensions if extension_versions[ext] <= STANDARD_VERSION]) + ENABLED_EXTENSIONS = tuple([ext for ext in all_extensions if extension_versions[ext] <= API_VERSION]) # We have to do this separately or it won't get added as the docstring set_array_api_strict_flags.__doc__ = set_array_api_strict_flags.__doc__.format( @@ -162,7 +162,7 @@ def get_array_api_strict_flags(): >>> from array_api_strict import get_array_api_strict_flags >>> flags = get_array_api_strict_flags() >>> flags - {'standard_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')} + {'api_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft')} See Also -------- @@ -173,7 +173,7 @@ def get_array_api_strict_flags(): """ return { - "standard_version": STANDARD_VERSION, + "api_version": API_VERSION, "data_dependent_shapes": DATA_DEPENDENT_SHAPES, "enabled_extensions": ENABLED_EXTENSIONS, } @@ -204,8 +204,8 @@ def reset_array_api_strict_flags(): ArrayApiStrictFlags: A context manager to temporarily set the flags. """ - global STANDARD_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS - STANDARD_VERSION = default_version + global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS + API_VERSION = default_version DATA_DEPENDENT_SHAPES = True ENABLED_EXTENSIONS = default_extensions @@ -230,10 +230,10 @@ class ArrayApiStrictFlags: reset_array_api_strict_flags """ - def __init__(self, *, standard_version=None, data_dependent_shapes=None, + def __init__(self, *, api_version=None, data_dependent_shapes=None, enabled_extensions=None): self.kwargs = { - "standard_version": standard_version, + "api_version": api_version, "data_dependent_shapes": data_dependent_shapes, "enabled_extensions": enabled_extensions, } @@ -248,9 +248,9 @@ def __exit__(self, exc_type, exc_value, traceback): # Private functions def set_flags_from_environment(): - if "ARRAY_API_STRICT_STANDARD_VERSION" in os.environ: + if "ARRAY_API_STRICT_API_VERSION" in os.environ: set_array_api_strict_flags( - standard_version=os.environ["ARRAY_API_STRICT_STANDARD_VERSION"] + api_version=os.environ["ARRAY_API_STRICT_API_VERSION"] ) if "ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES" in os.environ: diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index d3d957f..99037a5 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -18,7 +18,7 @@ def test_flags(): # Test defaults flags = get_array_api_strict_flags() assert flags == { - 'standard_version': '2022.12', + 'api_version': '2022.12', 'data_dependent_shapes': True, 'enabled_extensions': ('linalg', 'fft'), } @@ -27,33 +27,33 @@ def test_flags(): set_array_api_strict_flags(data_dependent_shapes=False) flags = get_array_api_strict_flags() assert flags == { - 'standard_version': '2022.12', + 'api_version': '2022.12', 'data_dependent_shapes': False, 'enabled_extensions': ('linalg', 'fft'), } set_array_api_strict_flags(enabled_extensions=('fft',)) flags = get_array_api_strict_flags() assert flags == { - 'standard_version': '2022.12', + 'api_version': '2022.12', 'data_dependent_shapes': False, 'enabled_extensions': ('fft',), } # Make sure setting the version to 2021.12 disables fft - set_array_api_strict_flags(standard_version='2021.12') + set_array_api_strict_flags(api_version='2021.12') flags = get_array_api_strict_flags() assert flags == { - 'standard_version': '2021.12', + 'api_version': '2021.12', 'data_dependent_shapes': False, 'enabled_extensions': ('linalg',), } # Test setting flags with invalid values pytest.raises(ValueError, lambda: - set_array_api_strict_flags(standard_version='2020.12')) + set_array_api_strict_flags(api_version='2020.12')) pytest.raises(ValueError, lambda: set_array_api_strict_flags( enabled_extensions=('linalg', 'fft', 'invalid'))) pytest.raises(ValueError, lambda: set_array_api_strict_flags( - standard_version='2021.12', + api_version='2021.12', enabled_extensions=('linalg', 'fft'))) From b4bfb8d14e6aaac1efb5e2cf0efb02cf10462614 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 16 Apr 2024 13:13:25 -0600 Subject: [PATCH 12/19] Move the reset_flags fixture to be global to all the tests --- array_api_strict/tests/conftest.py | 9 +++++++++ array_api_strict/tests/test_flags.py | 6 ------ 2 files changed, 9 insertions(+), 6 deletions(-) create mode 100644 array_api_strict/tests/conftest.py diff --git a/array_api_strict/tests/conftest.py b/array_api_strict/tests/conftest.py new file mode 100644 index 0000000..5000d5d --- /dev/null +++ b/array_api_strict/tests/conftest.py @@ -0,0 +1,9 @@ +from .._flags import reset_array_api_strict_flags + +import pytest + +@pytest.fixture(autouse=True) +def reset_flags(): + reset_array_api_strict_flags() + yield + reset_array_api_strict_flags() diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 99037a5..dff216a 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -8,12 +8,6 @@ import pytest -@pytest.fixture(autouse=True) -def reset_flags(): - reset_array_api_strict_flags() - yield - reset_array_api_strict_flags() - def test_flags(): # Test defaults flags = get_array_api_strict_flags() From 0d758ebac45c6a3e847d18e6cce31491fbd127db Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 16 Apr 2024 13:13:50 -0600 Subject: [PATCH 13/19] Test reset_array_api_strict_flags() --- array_api_strict/tests/test_flags.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index dff216a..5e5e171 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -50,6 +50,18 @@ def test_flags(): api_version='2021.12', enabled_extensions=('linalg', 'fft'))) + # Test resetting flags + set_array_api_strict_flags( + api_version='2021.12', + data_dependent_shapes=False, + enabled_extensions=()) + reset_array_api_strict_flags() + flags = get_array_api_strict_flags() + assert flags == { + 'api_version': '2022.12', + 'data_dependent_shapes': True, + 'enabled_extensions': ('linalg', 'fft'), + } def test_data_dependent_shapes(): a = asarray([0, 0, 1, 2, 2]) From 0ba1267d89f5f5cc37fd673606c8902aae526860 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 16 Apr 2024 13:14:09 -0600 Subject: [PATCH 14/19] Set __array_api_version__ with the api_version flag --- array_api_strict/__init__.py | 6 +++++- array_api_strict/_array_object.py | 2 +- array_api_strict/_flags.py | 5 ++++- array_api_strict/tests/test_array_object.py | 8 ++++++++ array_api_strict/tests/test_flags.py | 8 ++++++++ 5 files changed, 26 insertions(+), 3 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index 31f0992..b323a65 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -16,7 +16,11 @@ """ -__array_api_version__ = "2022.12" +# Warning: __array_api_version__ could change globally with +# set_array_api_strict_flags(). This should always be accessed as an +# attribute, like xp.__array_api_version__, or using +# array_api_strict.get_array_api_strict_flags()['api_version']. +from ._flags import API_VERSION as __array_api_version__ __all__ = ["__array_api_version__"] diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index 28d0eb9..ff9b8f8 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -497,7 +497,7 @@ def __array_namespace__( the API version for the array_api_strict module globally. This can also be achieved with the {func}`array_api_strict.set_array_api_strict_flags` function. If you - want some way to only set the version locally, use the + want to only set the version locally, use the {class}`array_api_strict.ArrayApiStrictFlags` context manager. """ diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index d02476d..80de965 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -14,6 +14,8 @@ import functools import os +import array_api_strict + supported_versions = ( "2021.12", "2022.12", @@ -37,7 +39,6 @@ "linalg", "fft", ) - # Public functions def set_array_api_strict_flags( @@ -118,6 +119,7 @@ def set_array_api_strict_flags( if api_version not in supported_versions: raise ValueError(f"Unsupported standard version {api_version!r}") API_VERSION = api_version + array_api_strict.__array_api_version__ = API_VERSION if data_dependent_shapes is not None: DATA_DEPENDENT_SHAPES = data_dependent_shapes @@ -206,6 +208,7 @@ def reset_array_api_strict_flags(): """ global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS API_VERSION = default_version + array_api_strict.__array_api_version__ = API_VERSION DATA_DEPENDENT_SHAPES = True ENABLED_EXTENSIONS = default_extensions diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index a0b6132..bae0553 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -402,9 +402,17 @@ def test_array_keys_use_private_array(): def test_array_namespace(): a = ones((3, 3)) assert a.__array_namespace__() == array_api_strict + assert array_api_strict.__array_api_version__ == "2022.12" + assert a.__array_namespace__(api_version=None) is array_api_strict + assert array_api_strict.__array_api_version__ == "2022.12" + assert a.__array_namespace__(api_version="2022.12") is array_api_strict + assert array_api_strict.__array_api_version__ == "2022.12" + with pytest.warns(UserWarning): assert a.__array_namespace__(api_version="2021.12") is array_api_strict + assert array_api_strict.__array_api_version__ == "2021.12" + pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2023.12")) diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index 5e5e171..f9d8ad6 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -63,6 +63,14 @@ def test_flags(): 'enabled_extensions': ('linalg', 'fft'), } +def test_api_version(): + # Test defaults + assert xp.__array_api_version__ == '2022.12' + + # Test setting the version + set_array_api_strict_flags(api_version='2021.12') + assert xp.__array_api_version__ == '2021.12' + def test_data_dependent_shapes(): a = asarray([0, 0, 1, 2, 2]) mask = asarray([True, False, True, False, True]) From 30baeb7daa5197d0babb84cbed15cb43402cb34b Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 18 Apr 2024 16:16:44 -0600 Subject: [PATCH 15/19] Move warning about 2021.12 to set_array_api_strict_flags() --- array_api_strict/_array_object.py | 3 --- array_api_strict/_flags.py | 6 ++++++ array_api_strict/tests/test_flags.py | 17 +++++++++++------ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index ff9b8f8..2b9155a 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -17,7 +17,6 @@ import operator from enum import IntEnum -import warnings from ._creation_functions import asarray from ._dtypes import ( @@ -502,8 +501,6 @@ def __array_namespace__( """ set_array_api_strict_flags(api_version=api_version) - if api_version == "2021.12": - warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") import array_api_strict return array_api_strict diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 80de965..2114620 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -13,6 +13,7 @@ import functools import os +import warnings import array_api_strict @@ -62,6 +63,9 @@ def set_array_api_strict_flags( versions are: ``{supported_versions}``. The default version number is ``{default_version!r}``. + Note that 2021.12 is supported, but currently gives the same thing as + 2022.12 (except that the fft extension will be disabled). + - `data_dependent_shapes`: Whether data-dependent shapes are enabled in array-api-strict. @@ -118,6 +122,8 @@ def set_array_api_strict_flags( if api_version is not None: if api_version not in supported_versions: raise ValueError(f"Unsupported standard version {api_version!r}") + if api_version == "2021.12": + warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12") API_VERSION = api_version array_api_strict.__array_api_version__ = API_VERSION diff --git a/array_api_strict/tests/test_flags.py b/array_api_strict/tests/test_flags.py index f9d8ad6..303c930 100644 --- a/array_api_strict/tests/test_flags.py +++ b/array_api_strict/tests/test_flags.py @@ -32,8 +32,12 @@ def test_flags(): 'data_dependent_shapes': False, 'enabled_extensions': ('fft',), } - # Make sure setting the version to 2021.12 disables fft - set_array_api_strict_flags(api_version='2021.12') + # Make sure setting the version to 2021.12 disables fft and issues a + # warning. + with pytest.warns(UserWarning) as record: + set_array_api_strict_flags(api_version='2021.12') + assert len(record) == 1 + assert '2021.12' in str(record[0].message) flags = get_array_api_strict_flags() assert flags == { 'api_version': '2021.12', @@ -51,10 +55,11 @@ def test_flags(): enabled_extensions=('linalg', 'fft'))) # Test resetting flags - set_array_api_strict_flags( - api_version='2021.12', - data_dependent_shapes=False, - enabled_extensions=()) + with pytest.warns(UserWarning): + set_array_api_strict_flags( + api_version='2021.12', + data_dependent_shapes=False, + enabled_extensions=()) reset_array_api_strict_flags() flags = get_array_api_strict_flags() assert flags == { From 22352d2f29fc9581959860d6c04026c4d88ba86f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 18 Apr 2024 16:35:13 -0600 Subject: [PATCH 16/19] Update some flags documentation --- array_api_strict/_flags.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index 2114620..e0344d8 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -1,14 +1,15 @@ """ -This file defines flags for that allow array-api-strict to be used in -different "modes". These modes include +These functions configure global flags that allow array-api-strict to be +used in different "modes". These modes include - Changing to different supported versions of the standard. - Enabling or disabling different optional behaviors (such as data-dependent shapes). - Enabling or disabling different optional extensions. -Nothing in this file is part of the standard itself. A typical array API +None of these functions are part of the standard itself. A typical array API library will only support one particular configuration of these flags. + """ import functools @@ -112,8 +113,8 @@ def set_array_api_strict_flags( See Also -------- - get_array_api_strict_flags - reset_array_api_strict_flags + get_array_api_strict_flags: Get the current values of flags. + reset_array_api_strict_flags: Reset the flags to their default values. ArrayApiStrictFlags: A context manager to temporarily set the flags. """ @@ -175,8 +176,8 @@ def get_array_api_strict_flags(): See Also -------- - set_array_api_strict_flags - reset_array_api_strict_flags + set_array_api_strict_flags: Set one or more flags to a given value. + reset_array_api_strict_flags: Reset the flags to their default values. ArrayApiStrictFlags: A context manager to temporarily set the flags. """ @@ -207,8 +208,8 @@ def reset_array_api_strict_flags(): See Also -------- - set_array_api_strict_flags - get_array_api_strict_flags + get_array_api_strict_flags: Get the current values of flags. + set_array_api_strict_flags: Set one or more flags to a given value. ArrayApiStrictFlags: A context manager to temporarily set the flags. """ @@ -234,9 +235,9 @@ class ArrayApiStrictFlags: See Also -------- - set_array_api_strict_flags - get_array_api_strict_flags - reset_array_api_strict_flags + get_array_api_strict_flags: Get the current values of flags. + set_array_api_strict_flags: Set one or more flags to a given value. + reset_array_api_strict_flags: Reset the flags to their default values. """ def __init__(self, *, api_version=None, data_dependent_shapes=None, From 654865dd1ca57beddfb931c2cfced2169406cfe0 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 18 Apr 2024 16:36:23 -0600 Subject: [PATCH 17/19] Rename ArrayApiStrictFlags to ArrayAPIStrictFlags --- array_api_strict/__init__.py | 4 ++-- array_api_strict/_flags.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/array_api_strict/__init__.py b/array_api_strict/__init__.py index b323a65..3f418d8 100644 --- a/array_api_strict/__init__.py +++ b/array_api_strict/__init__.py @@ -294,10 +294,10 @@ set_array_api_strict_flags, get_array_api_strict_flags, reset_array_api_strict_flags, - ArrayApiStrictFlags, + ArrayAPIStrictFlags, ) -__all__ += ['set_array_api_strict_flags', 'get_array_api_strict_flags', 'reset_array_api_strict_flags', 'ArrayApiStrictFlags'] +__all__ += ['set_array_api_strict_flags', 'get_array_api_strict_flags', 'reset_array_api_strict_flags', 'ArrayAPIStrictFlags'] from . import _version __version__ = _version.get_versions()['version'] diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index e0344d8..a07393a 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -115,7 +115,7 @@ def set_array_api_strict_flags( get_array_api_strict_flags: Get the current values of flags. reset_array_api_strict_flags: Reset the flags to their default values. - ArrayApiStrictFlags: A context manager to temporarily set the flags. + ArrayAPIStrictFlags: A context manager to temporarily set the flags. """ global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS @@ -178,7 +178,7 @@ def get_array_api_strict_flags(): set_array_api_strict_flags: Set one or more flags to a given value. reset_array_api_strict_flags: Reset the flags to their default values. - ArrayApiStrictFlags: A context manager to temporarily set the flags. + ArrayAPIStrictFlags: A context manager to temporarily set the flags. """ return { @@ -210,7 +210,7 @@ def reset_array_api_strict_flags(): get_array_api_strict_flags: Get the current values of flags. set_array_api_strict_flags: Set one or more flags to a given value. - ArrayApiStrictFlags: A context manager to temporarily set the flags. + ArrayAPIStrictFlags: A context manager to temporarily set the flags. """ global API_VERSION, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS @@ -220,7 +220,7 @@ def reset_array_api_strict_flags(): ENABLED_EXTENSIONS = default_extensions -class ArrayApiStrictFlags: +class ArrayAPIStrictFlags: """ A context manager to temporarily set the array-api-strict flags. From 0011d23ade7695d043fda38c08f571a43b928f79 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 18 Apr 2024 17:00:20 -0600 Subject: [PATCH 18/19] Add flags functions to Sphinx documentation --- array_api_strict/_flags.py | 20 ++++++++++--------- docs/api.rst | 39 ++++++++++++++++++++++++++++++++++++++ docs/changelog.md | 2 +- docs/conf.py | 4 ++-- docs/index.md | 2 ++ 5 files changed, 55 insertions(+), 12 deletions(-) create mode 100644 docs/api.rst diff --git a/array_api_strict/_flags.py b/array_api_strict/_flags.py index a07393a..6cc503a 100644 --- a/array_api_strict/_flags.py +++ b/array_api_strict/_flags.py @@ -91,22 +91,20 @@ def set_array_api_strict_flags( array-api-strict. The default is ``{default_extensions}``. Note that some extensions require a minimum version of the standard. - The default values of the flags can also be changed by setting environment - variables: - - - ``ARRAY_API_STRICT_API_VERSION``: A string representing the version number. - - ``ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES``: "True" or "False". - - ``ARRAY_API_STRICT_ENABLED_EXTENSIONS``: A comma separated list of - extensions to enable. + The flags can also be changed by setting :ref:`environment variables + `. Examples -------- >>> from array_api_strict import set_array_api_strict_flags + >>> # Set the standard version to 2021.12 >>> set_array_api_strict_flags(api_version="2021.12") + >>> # Disable data-dependent shapes >>> set_array_api_strict_flags(data_dependent_shapes=False) + >>> # Enable only the linalg extension (disable the fft extension) >>> set_array_api_strict_flags(enabled_extensions=["linalg"]) @@ -192,13 +190,17 @@ def reset_array_api_strict_flags(): """ Reset the array-api-strict flags to their default values. - This will also reset any flags that were set by environment variables. + This will also reset any flags that were set by :ref:`environment + variables ` back to their default values. .. note:: This function is **not** part of the array API standard. It only exists in array-api-strict. + See :func:`set_array_api_strict_flags` for a list of flags and their + default values. + Examples -------- @@ -229,7 +231,7 @@ class ArrayAPIStrictFlags: This class is **not** part of the array API standard. It only exists in array-api-strict. - See :func:`~.array_api_strict.set_array_api_strict_flags` for a + See :func:`set_array_api_strict_flags` for a description of the available flags. See Also diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 0000000..0827982 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,39 @@ +API Reference +============= + +.. automodule:: array_api_strict + +Array API Strict Flags +---------------------- + +.. automodule:: array_api_strict._flags + +.. currentmodule:: array_api_strict + +.. autofunction:: get_array_api_strict_flags +.. autofunction:: set_array_api_strict_flags +.. autofunction:: reset_array_api_strict_flags +.. autoclass:: ArrayAPIStrictFlags + +.. _environment-variables: + +Environment Variables +~~~~~~~~~~~~~~~~~~~~~ + +Flags can also be set with environment variables. +:func:`set_array_api_strict_flags` will override the values set by environment +variables. Note that the environment variables will only change the defaults +used by array-api-strict initially. They will not change the defaults used by +:func:`reset_array_api_strict_flags`. + +.. envvar:: ARRAY_API_STRICT_API_VERSION + + A string representing the version number. + +.. envvar:: ARRAY_API_STRICT_DATA_DEPENDENT_SHAPES + + "True" or "False" to enable or disable data dependent shapes. + +.. envvar:: ARRAY_API_STRICT_ENABLED_EXTENSIONS + + A comma separated list of extensions to enable. diff --git a/docs/changelog.md b/docs/changelog.md index 8f1c203..04c383d 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -20,7 +20,7 @@ This is the first release of `array_api_strict`. It is extracted from `numpy.array_api`, which was included as an experimental submodule in NumPy versions prior to 2.0. Note that the commit history in this repository is -extracted from the git history of numpy/array_api/ (see the [README](README.md)). +extracted from the git history of numpy/array_api/ (see [](numpy.array_api)). Additionally, the following changes are new to `array_api_strict` from `numpy.array_api` in NumPy 1.26 (the last NumPy feature release to include diff --git a/docs/conf.py b/docs/conf.py index c068b06..e4c66d7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,8 +22,8 @@ extensions = [ 'myst_parser', - # 'sphinx.ext.autodoc', - # 'sphinx.ext.napoleon', + 'sphinx.ext.autodoc', + 'sphinx.ext.napoleon', # 'sphinx.ext.intersphinx', 'sphinx_copybutton', ] diff --git a/docs/index.md b/docs/index.md index 307a9c2..6e84efa 100644 --- a/docs/index.md +++ b/docs/index.md @@ -183,6 +183,7 @@ issue, but this hasn't necessarily been tested thoroughly. API standard. [Support for 2023.12 is planned](https://github.com/data-apis/array-api-strict/issues/25). +(numpy.array_api)= ## Relationship to `numpy.array_api` Previously this implementation was available as `numpy.array_api`, but it was @@ -201,5 +202,6 @@ git_filter_repo.py --path numpy/array_api/ --path-rename numpy/array_api:array_a :titlesonly: :hidden: +api.rst changelog.md ``` From f92b4973251f1a90d56a6be5481cd6a91fd70413 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 18 Apr 2024 17:02:54 -0600 Subject: [PATCH 19/19] Add a note about docs for standard functions --- docs/api.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/api.rst b/docs/api.rst index 0827982..e703a63 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -37,3 +37,12 @@ used by array-api-strict initially. They will not change the defaults used by .. envvar:: ARRAY_API_STRICT_ENABLED_EXTENSIONS A comma separated list of extensions to enable. + +Array API Functions +-------------------- + +All functions and methods in +the array API standard are implemented in array-api-strict. See the `Array API +Standard +`__ for +full documentation for each function.