Skip to content

Commit 9f232b7

Browse files
committed
TST: adapt tests for the lack of __array__
1 parent 45c2d1f commit 9f232b7

File tree

3 files changed

+9
-34
lines changed

3 files changed

+9
-34
lines changed

array_api_strict/_creation_functions.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Generator
44
from contextlib import contextmanager
55
from enum import Enum
6-
from typing import TYPE_CHECKING, Literal
6+
from typing import TYPE_CHECKING, Literal, List, Optional, Tuple, Union
77

88
import numpy as np
99

@@ -26,20 +26,6 @@ class Undef(Enum):
2626
_undef = Undef.UNDEF
2727

2828

29-
@contextmanager
30-
def allow_array() -> Generator[None]:
31-
"""
32-
Temporarily enable Array.__array__. This is needed for np.array to parse
33-
list of lists of Array objects.
34-
"""
35-
from . import _array_object
36-
original_value = _array_object._allow_array
37-
try:
38-
_array_object._allow_array = True
39-
yield
40-
finally:
41-
_array_object._allow_array = original_value
42-
4329

4430
def _check_valid_dtype(dtype: DType | None) -> None:
4531
# Note: Only spelling dtypes as the dtype objects is supported.
@@ -123,8 +109,8 @@ def asarray(
123109
# Give a better error message in this case. NumPy would convert this
124110
# to an object array. TODO: This won't handle large integers in lists.
125111
raise OverflowError("Integer out of bounds for array dtypes")
126-
with allow_array():
127-
res = np.array(obj, dtype=_np_dtype, copy=copy)
112+
113+
res = np.array(obj, dtype=_np_dtype, copy=copy)
128114
return Array._new(res, device=device)
129115

130116

array_api_strict/tests/test_array_object.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import operator
23
from builtins import all as all_
34

@@ -526,6 +527,10 @@ def test_array_properties():
526527
assert b.mT.shape == (3, 2)
527528

528529

530+
@pytest.mark.xfail(sys.version_info.major*100 + sys.version_info.minor < 312,
531+
reason="array conversion relies on buffer protocol, and "
532+
"requires python >= 3.12"
533+
)
529534
def test_array_conversion():
530535
# Check that arrays on the CPU device can be converted to NumPy
531536
# but arrays on other devices can't. Note this is testing the logic in
@@ -539,22 +544,6 @@ def test_array_conversion():
539544
with pytest.raises(RuntimeError, match="Can not convert array"):
540545
np.asarray(a)
541546

542-
def test__array__():
543-
# __array__ should work for now
544-
a = ones((2, 3))
545-
np.array(a)
546-
547-
# Test the _allow_array private global flag for disabling it in the
548-
# future.
549-
from .. import _array_object
550-
original_value = _array_object._allow_array
551-
try:
552-
_array_object._allow_array = False
553-
a = ones((2, 3))
554-
with pytest.raises(ValueError, match="Conversion from an array_api_strict array to a NumPy ndarray is not supported"):
555-
np.array(a)
556-
finally:
557-
_array_object._allow_array = original_value
558547

559548
def test_allow_newaxis():
560549
a = ones(5)

array_api_strict/tests/test_creation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import warnings
23

34
from numpy.testing import assert_raises
@@ -97,7 +98,6 @@ def test_asarray_copy():
9798
a[0] = 0
9899
assert all(b[0] == 0)
99100

100-
101101
def test_asarray_list_of_lists():
102102
lst = [[1, 2, 3], [4, 5, 6]]
103103
res = asarray(lst)

0 commit comments

Comments
 (0)