Skip to content

Commit 878ba20

Browse files
Add test for dpt.asarray for objects with __usm_ndarray__ property
1 parent 787f80e commit 878ba20

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

dpctl/tests/test_tensor_asarray.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,3 +556,70 @@ def test_as_f_contig_square(dt):
556556
x3 = dpt.flip(x, axis=1)
557557
y3 = dpt.asarray(x3, order="F")
558558
assert dpt.all(x3 == y3)
559+
560+
561+
class MockArrayWithBothProtocols:
562+
"""
563+
Object that implements both __sycl_usm_array_interface__
564+
and __usm_ndarray__ properties.
565+
"""
566+
567+
def __init__(self, usm_ar):
568+
if not isinstance(usm_ar, dpt.usm_ndarray):
569+
raise TypeError
570+
self._arr = usm_ar
571+
572+
@property
573+
def __usm_ndarray__(self):
574+
return self._arr
575+
576+
@property
577+
def __sycl_usm_array_interface__(self):
578+
return self._arr.__sycl_usm_array_interface__
579+
580+
581+
class MockArrayWithSUAIOnly:
582+
"""
583+
Object that implements only the
584+
__sycl_usm_array_interface__ property.
585+
"""
586+
587+
def __init__(self, usm_ar):
588+
if not isinstance(usm_ar, dpt.usm_ndarray):
589+
raise TypeError
590+
self._arr = usm_ar
591+
592+
@property
593+
def __sycl_usm_array_interface__(self):
594+
return self._arr.__sycl_usm_array_interface__
595+
596+
597+
@pytest.mark.parametrize("usm_type", ["shared", "device", "host"])
598+
def test_asarray_support_for_usm_ndarray_protocol(usm_type):
599+
get_queue_or_skip()
600+
601+
x = dpt.arange(256, dtype="i4", usm_type=usm_type)
602+
603+
o1 = MockArrayWithBothProtocols(x)
604+
o2 = MockArrayWithSUAIOnly(x)
605+
606+
y1 = dpt.asarray(o1)
607+
assert x.sycl_queue == y1.sycl_queue
608+
assert x.usm_type == y1.usm_type
609+
assert x.dtype == y1.dtype
610+
assert y1.usm_data.reference_obj is None
611+
assert dpt.all(x == y1)
612+
613+
y2 = dpt.asarray(o2)
614+
assert x.sycl_queue == y2.sycl_queue
615+
assert x.usm_type == y2.usm_type
616+
assert x.dtype == y2.dtype
617+
assert not (y2.usm_data.reference_obj is None)
618+
assert dpt.all(x == y2)
619+
620+
y3 = dpt.asarray([o1, o2])
621+
assert x.sycl_queue == y3.sycl_queue
622+
assert x.usm_type == y3.usm_type
623+
assert x.dtype == y3.dtype
624+
assert y3.usm_data.reference_obj is None
625+
assert dpt.all(x[dpt.newaxis, :] == y3)

0 commit comments

Comments
 (0)