Skip to content

Commit 98560b8

Browse files
committed
MAINT: make __array__ raise on python < 3.12
Otherwise, on python 3.11 and below, np.array(array_api_strict_array) becomes a 0D object array.
1 parent 9f232b7 commit 98560b8

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

array_api_strict/_array_object.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,21 @@ def __repr__(self) -> str:
165165
# Instead of `__array__` we now implement the buffer protocol.
166166
# Note that it makes array-apis-strict requiring python>=3.12
167167
def __buffer__(self, flags):
168-
print('__buffer__')
169168
if self._device != CPU_DEVICE:
170169
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
171170
return memoryview(self._array)
172171
def __release_buffer(self, buffer):
173-
print('__release__')
174172
# XXX anything to do here?
173+
pass
174+
175+
def __array__(self, *args, **kwds):
176+
# a stub for python < 3.12; otherwise numpy silently produces object arrays
177+
import sys
178+
minor, major = sys.version_info.minor, sys.version_info.major
179+
if major < 3 or minor < 12:
180+
raise TypeError(
181+
"Interoperation with NumPy requires python >= 3.12. Please upgrade."
182+
)
175183

176184
# These are various helper functions to make the array behavior match the
177185
# spec in places where it either deviates from or is more strict than

array_api_strict/tests/test_array_object.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,9 +541,23 @@ def test_array_conversion():
541541

542542
for device in ("device1", "device2"):
543543
a = ones((2, 3), device=array_api_strict.Device(device))
544-
with pytest.raises(RuntimeError, match="Can not convert array"):
544+
with pytest.raises((RuntimeError, ValueError)):
545545
np.asarray(a)
546546

547+
# __buffer__ should work for now for conversion to numpy
548+
a = ones((2, 3))
549+
na = np.array(a)
550+
assert na.shape == (2, 3)
551+
assert na.dtype == np.float64
552+
553+
@pytest.mark.skipif(not sys.version_info.major*100 + sys.version_info.minor < 312,
554+
reason="conversion to numpy errors out unless python >= 3.12"
555+
)
556+
def test_array_conversion_2():
557+
a = ones((2, 3))
558+
with pytest.raises(TypeError):
559+
np.array(a)
560+
547561

548562
def test_allow_newaxis():
549563
a = ones(5)

0 commit comments

Comments
 (0)