Skip to content

Commit 55e5a61

Browse files
committed
MAINT: make __array__ raise on python < 3.12
Otherwise, on python 3.11 and below, np.array(array_api_strict_array) becomes a 0D object array.
1 parent 9f232b7 commit 55e5a61

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

array_api_strict/_array_object.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,18 @@ def __repr__(self) -> str:
165165
# Instead of `__array__` we now implement the buffer protocol.
166166
# Note that it makes array-apis-strict requiring python>=3.12
167167
def __buffer__(self, flags):
168-
print('__buffer__')
169168
if self._device != CPU_DEVICE:
170169
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
171170
return memoryview(self._array)
172171
def __release_buffer(self, buffer):
173-
print('__release__')
174172
# XXX anything to do here?
173+
pass
174+
175+
def __array__(self, *args, **kwds):
176+
# a stub for python < 3.12; otherwise numpy silently produces object arrays
177+
raise TypeError(
178+
"Interoperation with NumPy requires python >= 3.12. Please upgrade."
179+
)
175180

176181
# These are various helper functions to make the array behavior match the
177182
# spec in places where it either deviates from or is more strict than

array_api_strict/tests/test_array_object.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,20 @@ def test_array_conversion():
544544
with pytest.raises(RuntimeError, match="Can not convert array"):
545545
np.asarray(a)
546546

547+
# __buffer__ should work for now for conversion to numpy
548+
a = ones((2, 3))
549+
na = np.array(a)
550+
assert na.shape == (2, 3)
551+
assert na.dtype == np.float64
552+
553+
@pytest.mark.skipif(not sys.version_info.major*100 + sys.version_info.minor < 312,
554+
reason="conversion to numpy errors out unless python >= 3.12"
555+
)
556+
def test_array_conversion_2():
557+
a = ones((2, 3))
558+
with pytest.raises(TypeError):
559+
np.array(a)
560+
547561

548562
def test_allow_newaxis():
549563
a = ones(5)

0 commit comments

Comments
 (0)