Skip to content

Commit 1b0c979

Browse files
Add tests for importing generic legacy and generic modern containers
1 parent 424e4c8 commit 1b0c979

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,3 +733,76 @@ def test_copy_via_host_gh_1789():
733733
dpt.from_dlpack(x_np)
734734
with pytest.raises(BufferError):
735735
dpt.from_dlpack(x_np, device=(14, 0))
736+
737+
738+
class LegacyContainer:
739+
"Helper class implementing legacy `__dlpack__` protocol"
740+
741+
def __init__(self, array):
742+
self._array = array
743+
744+
def __dlpack__(self, stream=None):
745+
return self._array.__dlpack__(stream=stream)
746+
747+
def __dlpack_device__(self):
748+
return self._array.__dlpack_device__()
749+
750+
751+
class Container:
752+
"Helper class implementing legacy `__dlpack__` protocol"
753+
754+
def __init__(self, array):
755+
self._array = array
756+
757+
def __dlpack__(
758+
self, max_version=None, dl_device=None, copy=None, stream=None
759+
):
760+
return self._array.__dlpack__(
761+
max_version=max_version,
762+
dl_device=dl_device,
763+
copy=copy,
764+
stream=stream,
765+
)
766+
767+
def __dlpack_device__(self):
768+
return self._array.__dlpack_device__()
769+
770+
771+
def test_generic_container_legacy():
772+
get_queue_or_skip()
773+
C = LegacyContainer(dpt.linspace(0, 100, num=20, dtype="int16"))
774+
775+
X = dpt.from_dlpack(C)
776+
assert isinstance(X, dpt.usm_ndarray)
777+
assert X._pointer == C._array._pointer
778+
assert X.sycl_device == C._array.sycl_device
779+
assert X.dtype == C._array.dtype
780+
781+
Y = dpt.from_dlpack(C, device=(dpt.DLDeviceType.kDLCPU, 0))
782+
assert isinstance(Y, np.ndarray)
783+
assert Y.dtype == X.dtype
784+
785+
Z = dpt.from_dlpack(C, device=X.device)
786+
assert isinstance(Z, dpt.usm_ndarray)
787+
assert Z._pointer == X._pointer
788+
assert Z.device == X.device
789+
790+
791+
def test_generic_container():
792+
get_queue_or_skip()
793+
C = Container(dpt.linspace(0, 100, num=20, dtype="int16"))
794+
795+
X = dpt.from_dlpack(C)
796+
assert isinstance(X, dpt.usm_ndarray)
797+
assert X._pointer == C._array._pointer
798+
assert X.sycl_device == C._array.sycl_device
799+
assert X.dtype == C._array.dtype
800+
801+
Y = dpt.from_dlpack(C, device=(dpt.DLDeviceType.kDLCPU, 0))
802+
assert isinstance(Y, np.ndarray)
803+
assert Y.dtype == X.dtype
804+
805+
Z = dpt.from_dlpack(C, device=X.device)
806+
assert isinstance(Z, dpt.usm_ndarray)
807+
assert Z._pointer == X._pointer
808+
assert Z.device == X.device

0 commit comments

Comments
 (0)