Skip to content

Commit 9ad1bb5

Browse files
authored
Implement __usm_ndarray__ protocol (#2261)
The PR is intended to adopt to dpctl changes implemented in [dpctl#1959](IntelPython/dpctl#1959). It implements support of `__usm_ndarray__` protocol for `dpnp.ndarray` and returns a property with `dpctl.tensor.usm_ndarray` instance corresponding to the content of the array object. This property is intended to speed-up conversion from `dpnp.ndarray` to `dpt.usm_ndarray` in `x=dpt.asarray(dpnp_array_obj)`. The input object that implements `__usm_ndarray__` is recognized as owner of USM allocation that is managed by a smart pointer, and asynchronous deallocation of `x` need not involve GIL.
1 parent 6cc2348 commit 9ad1bb5

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

dpnp/dpnp_array.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,25 @@ def __truediv__(self, other):
605605
"""Return ``self/value``."""
606606
return dpnp.true_divide(self, other)
607607

608+
@property
609+
def __usm_ndarray__(self):
610+
"""
611+
Property to support `__usm_ndarray__` protocol.
612+
613+
It assumes to return :class:`dpctl.tensor.usm_ndarray` instance
614+
corresponding to the content of the object.
615+
616+
This property is intended to speed-up conversion from
617+
:class:`dpnp.ndarray` to :class:`dpctl.tensor.usm_ndarray` passed
618+
into `dpctl.tensor.asarray` function. The input object that implements
619+
`__usm_ndarray__` protocol is recognized as owner of USM allocation
620+
that is managed by a smart pointer, and asynchronous deallocation
621+
will not involve GIL.
622+
623+
"""
624+
625+
return self._array_obj
626+
608627
def __xor__(self, other):
609628
"""Return ``self^value``."""
610629
return dpnp.bitwise_xor(self, other)

dpnp/tests/test_ndarray.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,18 @@ def test_error(self):
176176
ia.item()
177177

178178

179+
class TestUsmNdarrayProtocol:
180+
def test_basic(self):
181+
a = dpnp.arange(256, dtype=dpnp.int64)
182+
usm_a = dpt.asarray(a)
183+
184+
assert a.sycl_queue == usm_a.sycl_queue
185+
assert a.usm_type == usm_a.usm_type
186+
assert a.dtype == usm_a.dtype
187+
assert usm_a.usm_data.reference_obj is None
188+
assert (a == usm_a).all()
189+
190+
179191
def test_print_dpnp_int():
180192
result = repr(dpnp.array([1, 0, 2, -3, -1, 2, 21, -9], dtype="i4"))
181193
expected = "array([ 1, 0, 2, -3, -1, 2, 21, -9], dtype=int32)"

0 commit comments

Comments
 (0)