Skip to content

Commit 829d515

Browse files
authored
Merge pull request #24 from asmeurer/api_version
Update api_version argument to __array_namespace__
2 parents 3092e7b + 9a03ab8 commit 829d515

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

array_api_strict/_array_object.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import operator
1919
from enum import IntEnum
20+
import warnings
21+
2022
from ._creation_functions import asarray
2123
from ._dtypes import (
2224
_DType,
@@ -480,8 +482,10 @@ def __and__(self: Array, other: Union[int, bool, Array], /) -> Array:
480482
def __array_namespace__(
481483
self: Array, /, *, api_version: Optional[str] = None
482484
) -> types.ModuleType:
483-
if api_version is not None and not api_version.startswith("2021."):
485+
if api_version is not None and api_version not in ["2021.12", "2022.12"]:
484486
raise ValueError(f"Unrecognized array API version: {api_version!r}")
487+
if api_version == "2021.12":
488+
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
485489
import array_api_strict
486490
return array_api_strict
487491

array_api_strict/tests/test_array_object.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
uint64,
2424
bool as bool_,
2525
)
26-
26+
import array_api_strict
2727

2828
def test_validate_index():
2929
# The indexing tests in the official array API test suite test that the
@@ -398,3 +398,13 @@ def test_array_keys_use_private_array():
398398
key = ones((0, 0), dtype=bool_)
399399
with pytest.raises(IndexError):
400400
a[key]
401+
402+
def test_array_namespace():
403+
a = ones((3, 3))
404+
assert a.__array_namespace__() == array_api_strict
405+
assert a.__array_namespace__(api_version=None) is array_api_strict
406+
assert a.__array_namespace__(api_version="2022.12") is array_api_strict
407+
with pytest.warns(UserWarning):
408+
assert a.__array_namespace__(api_version="2021.12") is array_api_strict
409+
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
410+
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2023.12"))

0 commit comments

Comments
 (0)