@@ -556,3 +556,70 @@ def test_as_f_contig_square(dt):
556
556
x3 = dpt .flip (x , axis = 1 )
557
557
y3 = dpt .asarray (x3 , order = "F" )
558
558
assert dpt .all (x3 == y3 )
559
+
560
+
561
+ class MockArrayWithBothProtocols :
562
+ """
563
+ Object that implements both __sycl_usm_array_interface__
564
+ and __usm_ndarray__ properties.
565
+ """
566
+
567
+ def __init__ (self , usm_ar ):
568
+ if not isinstance (usm_ar , dpt .usm_ndarray ):
569
+ raise TypeError
570
+ self ._arr = usm_ar
571
+
572
+ @property
573
+ def __usm_ndarray__ (self ):
574
+ return self ._arr
575
+
576
+ @property
577
+ def __sycl_usm_array_interface__ (self ):
578
+ return self ._arr .__sycl_usm_array_interface__
579
+
580
+
581
+ class MockArrayWithSUAIOnly :
582
+ """
583
+ Object that implements only the
584
+ __sycl_usm_array_interface__ property.
585
+ """
586
+
587
+ def __init__ (self , usm_ar ):
588
+ if not isinstance (usm_ar , dpt .usm_ndarray ):
589
+ raise TypeError
590
+ self ._arr = usm_ar
591
+
592
+ @property
593
+ def __sycl_usm_array_interface__ (self ):
594
+ return self ._arr .__sycl_usm_array_interface__
595
+
596
+
597
+ @pytest .mark .parametrize ("usm_type" , ["shared" , "device" , "host" ])
598
+ def test_asarray_support_for_usm_ndarray_protocol (usm_type ):
599
+ get_queue_or_skip ()
600
+
601
+ x = dpt .arange (256 , dtype = "i4" , usm_type = usm_type )
602
+
603
+ o1 = MockArrayWithBothProtocols (x )
604
+ o2 = MockArrayWithSUAIOnly (x )
605
+
606
+ y1 = dpt .asarray (o1 )
607
+ assert x .sycl_queue == y1 .sycl_queue
608
+ assert x .usm_type == y1 .usm_type
609
+ assert x .dtype == y1 .dtype
610
+ assert y1 .usm_data .reference_obj is None
611
+ assert dpt .all (x == y1 )
612
+
613
+ y2 = dpt .asarray (o2 )
614
+ assert x .sycl_queue == y2 .sycl_queue
615
+ assert x .usm_type == y2 .usm_type
616
+ assert x .dtype == y2 .dtype
617
+ assert not (y2 .usm_data .reference_obj is None )
618
+ assert dpt .all (x == y2 )
619
+
620
+ y3 = dpt .asarray ([o1 , o2 ])
621
+ assert x .sycl_queue == y3 .sycl_queue
622
+ assert x .usm_type == y3 .usm_type
623
+ assert x .dtype == y3 .dtype
624
+ assert y3 .usm_data .reference_obj is None
625
+ assert dpt .all (x [dpt .newaxis , :] == y3 )
0 commit comments