Skip to content

Commit cb4fd57

Browse files
committed
Add tests for dldevice and sycldevice interchange functions
1 parent c529b29 commit cb4fd57

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,3 +826,38 @@ def test_generic_container():
826826
assert isinstance(Z, dpt.usm_ndarray)
827827
assert Z._pointer == X._pointer
828828
assert Z.device == X.device
829+
830+
831+
def test_sycldevice_to_dldevice(all_root_devices):
832+
for sycl_dev in all_root_devices:
833+
dev = dpt.sycldevice_to_dldevice(sycl_dev)
834+
assert type(dev) is tuple
835+
assert len(dev) == 2
836+
assert dev[0] == device_oneAPI
837+
assert dev[1] == all_root_devices.index(sycl_dev)
838+
839+
840+
def test_dldevice_to_sycldevice(all_root_devices):
841+
for sycl_dev in all_root_devices:
842+
dldev = dpt.empty(0, device=sycl_dev).__dlpack_device__()
843+
dev = dpt.dldevice_to_sycldevice(dldev)
844+
assert type(dev) is dpctl.SyclDevice
845+
assert dev == all_root_devices[dldev[1]]
846+
847+
848+
def test_dldevice_conversion_arg_validation():
849+
bad_dldevice_type = (dpt.DLDeviceType.kDLCPU, 0)
850+
with pytest.raises(ValueError):
851+
dpt.dldevice_to_sycldevice(bad_dldevice_type)
852+
853+
bad_dldevice_len = bad_dldevice_type + (0,)
854+
with pytest.raises(ValueError):
855+
dpt.dldevice_to_sycldevice(bad_dldevice_len)
856+
857+
bad_dldevice = dict()
858+
with pytest.raises(TypeError):
859+
dpt.dldevice_to_sycldevice(bad_dldevice)
860+
861+
bad_sycldevice = dict()
862+
with pytest.raises(TypeError):
863+
dpt.sycldevice_to_dldevice(bad_sycldevice)

0 commit comments

Comments
 (0)