Skip to content

Commit 6c3b7d6

Browse files
committed
Temporarily enable __array__ in asarray so that parsing list of lists of Array can work
1 parent 38551c6 commit 6c3b7d6

File tree

3 files changed

+48
-3
lines changed

3 files changed

+48
-3
lines changed

array_api_strict/_array_object.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def __repr__(self):
5353

5454
_default = object()
5555

56+
_allow_array = False
57+
5658
class Array:
5759
"""
5860
n-d array object for the array API namespace.
@@ -135,6 +137,22 @@ def __repr__(self: Array, /) -> str:
135137
# lead to code assuming np.asarray(other_array) would always work in the
136138
# standard.
137139
def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]:
140+
# We have to allow this to be internally enabled as there's no other
141+
# easy way to parse a list of Array objects in asarray().
142+
if _allow_array:
143+
# copy keyword is new in 2.0.0; for older versions don't use it
144+
# retry without that keyword.
145+
if np.__version__[0] < '2':
146+
return np.asarray(self._array, dtype=dtype)
147+
elif np.__version__.startswith('2.0.0-dev0'):
148+
# Handle dev version for which we can't know based on version
149+
# number whether or not the copy keyword is supported.
150+
try:
151+
return np.asarray(self._array, dtype=dtype, copy=copy)
152+
except TypeError:
153+
return np.asarray(self._array, dtype=dtype)
154+
else:
155+
return np.asarray(self._array, dtype=dtype, copy=copy)
138156
raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported")
139157

140158
# These are various helper functions to make the array behavior match the

array_api_strict/_creation_functions.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
3+
from contextlib import contextmanager
44
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
55

66
if TYPE_CHECKING:
@@ -16,6 +16,19 @@
1616

1717
import numpy as np
1818

19+
@contextmanager
20+
def allow_array():
21+
"""
22+
Temporarily enable Array.__array__. This is needed for np.array to parse
23+
list of lists of Array objects.
24+
"""
25+
from . import _array_object
26+
original_value = _array_object._allow_array
27+
try:
28+
_array_object._allow_array = True
29+
yield
30+
finally:
31+
_array_object._allow_array = original_value
1932

2033
def _check_valid_dtype(dtype):
2134
# Note: Only spelling dtypes as the dtype objects is supported.
@@ -94,7 +107,8 @@ def asarray(
94107
# Give a better error message in this case. NumPy would convert this
95108
# to an object array. TODO: This won't handle large integers in lists.
96109
raise OverflowError("Integer out of bounds for array dtypes")
97-
res = np.array(obj, dtype=_np_dtype, copy=copy)
110+
with allow_array():
111+
res = np.array(obj, dtype=_np_dtype, copy=copy)
98112
return Array._new(res)
99113

100114

array_api_strict/tests/test_creation_functions.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
zeros,
2323
zeros_like,
2424
)
25-
from .._dtypes import float32, float64
25+
from .._dtypes import int16, float32, float64
2626
from .._array_object import Array, CPU_DEVICE
2727
from .._flags import set_array_api_strict_flags
2828

@@ -97,6 +97,19 @@ def test_asarray_copy():
9797
a[0] = 0
9898
assert all(b[0] == 0)
9999

100+
def test_asarray_list_of_lists():
101+
a = asarray(1, dtype=int16)
102+
b = asarray([1], dtype=int16)
103+
res = asarray([a, a])
104+
assert res.shape == (2,)
105+
assert res.dtype == int16
106+
assert all(res == asarray([1, 1]))
107+
108+
res = asarray([b, b])
109+
assert res.shape == (2, 1)
110+
assert res.dtype == int16
111+
assert all(res == asarray([[1], [1]]))
112+
100113
def test_arange_errors():
101114
arange(1, device=CPU_DEVICE) # Doesn't error
102115
assert_raises(ValueError, lambda: arange(1, device="cpu"))

0 commit comments

Comments
 (0)