Skip to content

Commit f8e7c84

Browse files
authored
Merge pull request #86 from asmeurer/dlpack-fix
Allow __dlpack__ to work with newer versions of NumPy
2 parents ae02e78 + 93201f1 commit f8e7c84

File tree

2 files changed

+42
-24
lines changed

2 files changed

+42
-24
lines changed

array_api_strict/_array_object.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939
if TYPE_CHECKING:
4040
from typing import Optional, Tuple, Union, Any
41-
from ._typing import PyCapsule, Device, Dtype
41+
from ._typing import PyCapsule, Dtype
4242
import numpy.typing as npt
4343

4444
import numpy as np
@@ -586,15 +586,24 @@ def __dlpack__(
586586
if copy is not _default:
587587
raise ValueError("The copy argument to __dlpack__ requires at least version 2023.12 of the array API")
588588

589-
# Going to wait for upstream numpy support
590-
if max_version not in [_default, None]:
591-
raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented")
592-
if dl_device not in [_default, None]:
593-
raise NotImplementedError("The device argument to __dlpack__ is not yet implemented")
594-
if copy not in [_default, None]:
595-
raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented")
589+
if np.__version__[0] < '2.1':
590+
if max_version not in [_default, None]:
591+
raise NotImplementedError("The max_version argument to __dlpack__ is not yet implemented")
592+
if dl_device not in [_default, None]:
593+
raise NotImplementedError("The device argument to __dlpack__ is not yet implemented")
594+
if copy not in [_default, None]:
595+
raise NotImplementedError("The copy argument to __dlpack__ is not yet implemented")
596596

597-
return self._array.__dlpack__(stream=stream)
597+
return self._array.__dlpack__(stream=stream)
598+
else:
599+
kwargs = {'stream': stream}
600+
if max_version is not _default:
601+
kwargs['max_version'] = max_version
602+
if dl_device is not _default:
603+
kwargs['dl_device'] = dl_device
604+
if copy is not _default:
605+
kwargs['copy'] = copy
606+
return self._array.__dlpack__(**kwargs)
598607

599608
def __dlpack_device__(self: Array, /) -> Tuple[IntEnum, int]:
600609
"""

array_api_strict/tests/test_array_object.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -460,18 +460,27 @@ def dlpack_2023_12(api_version):
460460
a.__dlpack__()
461461

462462

463-
exception = NotImplementedError if api_version >= '2023.12' else ValueError
464-
pytest.raises(exception, lambda:
465-
a.__dlpack__(dl_device=CPU_DEVICE))
466-
pytest.raises(exception, lambda:
467-
a.__dlpack__(dl_device=None))
468-
pytest.raises(exception, lambda:
469-
a.__dlpack__(max_version=(1, 0)))
470-
pytest.raises(exception, lambda:
471-
a.__dlpack__(max_version=None))
472-
pytest.raises(exception, lambda:
473-
a.__dlpack__(copy=False))
474-
pytest.raises(exception, lambda:
475-
a.__dlpack__(copy=True))
476-
pytest.raises(exception, lambda:
477-
a.__dlpack__(copy=None))
463+
if np.__version__ < '2.1':
464+
exception = NotImplementedError if api_version >= '2023.12' else ValueError
465+
pytest.raises(exception, lambda:
466+
a.__dlpack__(dl_device=CPU_DEVICE))
467+
pytest.raises(exception, lambda:
468+
a.__dlpack__(dl_device=None))
469+
pytest.raises(exception, lambda:
470+
a.__dlpack__(max_version=(1, 0)))
471+
pytest.raises(exception, lambda:
472+
a.__dlpack__(max_version=None))
473+
pytest.raises(exception, lambda:
474+
a.__dlpack__(copy=False))
475+
pytest.raises(exception, lambda:
476+
a.__dlpack__(copy=True))
477+
pytest.raises(exception, lambda:
478+
a.__dlpack__(copy=None))
479+
else:
480+
a.__dlpack__(dl_device=CPU_DEVICE)
481+
a.__dlpack__(dl_device=None)
482+
a.__dlpack__(max_version=(1, 0))
483+
a.__dlpack__(max_version=None)
484+
a.__dlpack__(copy=False)
485+
a.__dlpack__(copy=True)
486+
a.__dlpack__(copy=None)

0 commit comments

Comments
 (0)