@@ -826,3 +826,38 @@ def test_generic_container():
826
826
assert isinstance (Z , dpt .usm_ndarray )
827
827
assert Z ._pointer == X ._pointer
828
828
assert Z .device == X .device
829
+
830
+
831
+ def test_sycldevice_to_dldevice (all_root_devices ):
832
+ for sycl_dev in all_root_devices :
833
+ dev = dpt .sycldevice_to_dldevice (sycl_dev )
834
+ assert type (dev ) is tuple
835
+ assert len (dev ) == 2
836
+ assert dev [0 ] == device_oneAPI
837
+ assert dev [1 ] == all_root_devices .index (sycl_dev )
838
+
839
+
840
+ def test_dldevice_to_sycldevice (all_root_devices ):
841
+ for sycl_dev in all_root_devices :
842
+ dldev = dpt .empty (0 , device = sycl_dev ).__dlpack_device__ ()
843
+ dev = dpt .dldevice_to_sycldevice (dldev )
844
+ assert type (dev ) is dpctl .SyclDevice
845
+ assert dev == all_root_devices [dldev [1 ]]
846
+
847
+
848
+ def test_dldevice_conversion_arg_validation ():
849
+ bad_dldevice_type = (dpt .DLDeviceType .kDLCPU , 0 )
850
+ with pytest .raises (ValueError ):
851
+ dpt .dldevice_to_sycldevice (bad_dldevice_type )
852
+
853
+ bad_dldevice_len = bad_dldevice_type + (0 ,)
854
+ with pytest .raises (ValueError ):
855
+ dpt .dldevice_to_sycldevice (bad_dldevice_len )
856
+
857
+ bad_dldevice = dict ()
858
+ with pytest .raises (TypeError ):
859
+ dpt .dldevice_to_sycldevice (bad_dldevice )
860
+
861
+ bad_sycldevice = dict ()
862
+ with pytest .raises (TypeError ):
863
+ dpt .sycldevice_to_dldevice (bad_sycldevice )
0 commit comments