Skip to content

array-api-strict flags #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions array_api_strict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@

"""

__array_api_version__ = "2022.12"
# Warning: __array_api_version__ could change globally with
# set_array_api_strict_flags(). This should always be accessed as an
# attribute, like xp.__array_api_version__, or using
# array_api_strict.get_array_api_strict_flags()['api_version'].
from ._flags import API_VERSION as __array_api_version__

__all__ = ["__array_api_version__"]

Expand Down Expand Up @@ -244,7 +248,7 @@

__all__ += ["linalg"]

from .linalg import matmul, tensordot, matrix_transpose, vecdot
from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot

__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]

Expand Down Expand Up @@ -284,6 +288,17 @@

__all__ += ["all", "any"]

# Helper functions that are not part of the standard

from ._flags import (
set_array_api_strict_flags,
get_array_api_strict_flags,
reset_array_api_strict_flags,
ArrayAPIStrictFlags,
)

__all__ += ['set_array_api_strict_flags', 'get_array_api_strict_flags', 'reset_array_api_strict_flags', 'ArrayAPIStrictFlags']

from . import _version
__version__ = _version.get_versions()['version']
del _version
39 changes: 27 additions & 12 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import operator
from enum import IntEnum
import warnings

from ._creation_functions import asarray
from ._dtypes import (
Expand All @@ -32,6 +31,7 @@
_result_type,
_dtype_categories,
)
from ._flags import get_array_api_strict_flags, set_array_api_strict_flags

from typing import TYPE_CHECKING, Optional, Tuple, Union, Any, SupportsIndex
import types
Expand Down Expand Up @@ -427,13 +427,17 @@ def _validate_index(self, key):
"the Array API)"
)
elif isinstance(i, Array):
if i.dtype in _boolean_dtypes and len(_key) != 1:
assert isinstance(key, tuple) # sanity check
raise IndexError(
f"Single-axes index {i} is a boolean array and "
f"{len(key)=}, but masking is only specified in the "
"Array API when the array is the sole index."
)
if i.dtype in _boolean_dtypes:
if len(_key) != 1:
assert isinstance(key, tuple) # sanity check
raise IndexError(
f"Single-axes index {i} is a boolean array and "
f"{len(key)=}, but masking is only specified in the "
"Array API when the array is the sole index."
)
if not get_array_api_strict_flags()['data_dependent_shapes']:
raise RuntimeError("Boolean array indexing (masking) requires data-dependent shapes, but the data_dependent_shapes flag has been disabled for array-api-strict")

elif i.dtype in _integer_dtypes and i.ndim != 0:
raise IndexError(
f"Single-axes index {i} is a non-zero-dimensional "
Expand Down Expand Up @@ -482,10 +486,21 @@ def __and__(self: Array, other: Union[int, bool, Array], /) -> Array:
def __array_namespace__(
self: Array, /, *, api_version: Optional[str] = None
) -> types.ModuleType:
if api_version is not None and api_version not in ["2021.12", "2022.12"]:
raise ValueError(f"Unrecognized array API version: {api_version!r}")
if api_version == "2021.12":
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
"""
Return the array_api_strict namespace corresponding to api_version.

The default API version is '2022.12'. Note that '2021.12' is supported,
but currently identical to '2022.12'.

For array_api_strict, calling this function with api_version will set
the API version for the array_api_strict module globally. This can
also be achieved with the
{func}`array_api_strict.set_array_api_strict_flags` function. If you
want to only set the version locally, use the
{class}`array_api_strict.ArrayApiStrictFlags` context manager.

"""
set_array_api_strict_flags(api_version=api_version)
import array_api_strict
return array_api_strict

Expand Down
Loading