@@ -733,3 +733,76 @@ def test_copy_via_host_gh_1789():
733
733
dpt .from_dlpack (x_np )
734
734
with pytest .raises (BufferError ):
735
735
dpt .from_dlpack (x_np , device = (14 , 0 ))
736
+
737
+
738
+ class LegacyContainer :
739
+ "Helper class implementing legacy `__dlpack__` protocol"
740
+
741
+ def __init__ (self , array ):
742
+ self ._array = array
743
+
744
+ def __dlpack__ (self , stream = None ):
745
+ return self ._array .__dlpack__ (stream = stream )
746
+
747
+ def __dlpack_device__ (self ):
748
+ return self ._array .__dlpack_device__ ()
749
+
750
+
751
+ class Container :
752
+ "Helper class implementing legacy `__dlpack__` protocol"
753
+
754
+ def __init__ (self , array ):
755
+ self ._array = array
756
+
757
+ def __dlpack__ (
758
+ self , max_version = None , dl_device = None , copy = None , stream = None
759
+ ):
760
+ return self ._array .__dlpack__ (
761
+ max_version = max_version ,
762
+ dl_device = dl_device ,
763
+ copy = copy ,
764
+ stream = stream ,
765
+ )
766
+
767
+ def __dlpack_device__ (self ):
768
+ return self ._array .__dlpack_device__ ()
769
+
770
+
771
+ def test_generic_container_legacy ():
772
+ get_queue_or_skip ()
773
+ C = LegacyContainer (dpt .linspace (0 , 100 , num = 20 , dtype = "int16" ))
774
+
775
+ X = dpt .from_dlpack (C )
776
+ assert isinstance (X , dpt .usm_ndarray )
777
+ assert X ._pointer == C ._array ._pointer
778
+ assert X .sycl_device == C ._array .sycl_device
779
+ assert X .dtype == C ._array .dtype
780
+
781
+ Y = dpt .from_dlpack (C , device = (dpt .DLDeviceType .kDLCPU , 0 ))
782
+ assert isinstance (Y , np .ndarray )
783
+ assert Y .dtype == X .dtype
784
+
785
+ Z = dpt .from_dlpack (C , device = X .device )
786
+ assert isinstance (Z , dpt .usm_ndarray )
787
+ assert Z ._pointer == X ._pointer
788
+ assert Z .device == X .device
789
+
790
+
791
+ def test_generic_container ():
792
+ get_queue_or_skip ()
793
+ C = Container (dpt .linspace (0 , 100 , num = 20 , dtype = "int16" ))
794
+
795
+ X = dpt .from_dlpack (C )
796
+ assert isinstance (X , dpt .usm_ndarray )
797
+ assert X ._pointer == C ._array ._pointer
798
+ assert X .sycl_device == C ._array .sycl_device
799
+ assert X .dtype == C ._array .dtype
800
+
801
+ Y = dpt .from_dlpack (C , device = (dpt .DLDeviceType .kDLCPU , 0 ))
802
+ assert isinstance (Y , np .ndarray )
803
+ assert Y .dtype == X .dtype
804
+
805
+ Z = dpt .from_dlpack (C , device = X .device )
806
+ assert isinstance (Z , dpt .usm_ndarray )
807
+ assert Z ._pointer == X ._pointer
808
+ assert Z .device == X .device
0 commit comments