Skip to content

Commit 4c80f6d

Browse files
committed
Allow __dlpack__ to work with newer versions of NumPy
1 parent ae02e78 commit 4c80f6d

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

array_api_strict/_array_object.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -586,15 +586,16 @@ def __dlpack__(
586586
if copy is not _default:
587587
raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API")
588588

589-
# Going to wait for upstream numpy support
590-
if max_version not in [_default, None]:
591-
raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented")
592-
if dl_device not in [_default, None]:
593-
raise NotImplementedError("The device argument to __dlpack__ is not yet implemented")
594-
if copy not in [_default, None]:
595-
raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented")
596-
597-
return self._array.__dlpack__(stream=stream)
589+
if np.__version__ < '2.1':
590+
if max_version not in [_default, None]:
591+
raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented")
592+
if dl_device not in [_default, None]:
593+
raise NotImplementedError("The device argument to __dlpack__ is not yet implemented")
594+
if copy not in [_default, None]:
595+
raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented")
596+
597+
return self._array.__dlpack__(stream=stream)
598+
return self._array.__dlpack__(stream=stream, max_version=max_version, dl_device=dl_device, copy=copy)
598599

599600
def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
600601
"""

array_api_strict/tests/test_array_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def dlpack_2023_12(api_version):
460460
a.__dlpack__()
461461

462462

463-
exception = NotImplementedError if api_version >= '2023.12' else ValueError
463+
exception = NotImplementedError if api_version >= '2023.12' and np.__version__ < '2.1' else ValueError
464464
pytest.raises(exception, lambda:
465465
a.__dlpack__(dl_device=CPU_DEVICE))
466466
pytest.raises(exception, lambda:

0 commit comments

Comments
 (0)