Skip to content

Commit 787f80e

Browse files
tensor.asarray support for objects with __usm_ndarray__ attribute added
`asarray` supports objects that implement `__sycl_usm_array_interface__`. It can create a view into USM allocation owned by input object, hence maintains a reference to it. Asynchronous deallocation of such objects in dpctl.tensor functions require manipulating Python object reference counters, and hold GIL. This is a source of dead-locks, and affects performance. This PR adds support for ingesting Python objects that implement __usm_ndarray__ attribute (property) that returns dpt.usm_ndarray object with such a view directly. It is trivial for `dpnp.ndarray` to implement such a property, e.g,. ``` @Property def __usm_ndarray__(self): return self._array_obj ``` With this definition, `dpt.asarray(dpnp_array)` will recognize that the underlying USM allocation is managed by the smart pointer, and asynchronous deallocation will not involve Python objects, avoiding dead-locks and improving performance.
1 parent 9f8f90b commit 787f80e

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

dpctl/tensor/_ctors.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ def _array_info_dispatch(obj):
6161
if _is_object_with_buffer_protocol(obj):
6262
np_obj = np.array(obj)
6363
return np_obj.shape, np_obj.dtype, _host_set
64+
if hasattr(obj, "__usm_ndarray__"):
65+
usm_ar = getattr(obj, "__usm_ndarray__")
66+
if isinstance(usm_ar, dpt.usm_ndarray):
67+
return usm_ar.shape, usm_ar.dtype, frozenset([usm_ar.sycl_queue])
6468
if hasattr(obj, "__sycl_usm_array_interface__"):
6569
usm_ar = _usm_ndarray_from_suai(obj)
6670
return usm_ar.shape, usm_ar.dtype, frozenset([usm_ar.sycl_queue])
@@ -306,6 +310,11 @@ def _usm_types_walker(o, usm_types_list):
306310
if isinstance(o, dpt.usm_ndarray):
307311
usm_types_list.append(o.usm_type)
308312
return
313+
if hasattr(o, "__usm_ndarray__"):
314+
usm_arr = getattr(o, "__usm_ndarray__")
315+
if isinstance(usm_arr, dpt.usm_ndarray):
316+
usm_types_list.append(usm_arr.usm_type)
317+
return
309318
if hasattr(o, "__sycl_usm_array_interface__"):
310319
usm_ar = _usm_ndarray_from_suai(o)
311320
usm_types_list.append(usm_ar.usm_type)
@@ -330,6 +339,11 @@ def _device_copy_walker(seq_o, res, _manager):
330339
)
331340
_manager.add_event_pair(ht_ev, cpy_ev)
332341
return
342+
if hasattr(seq_o, "__usm_ndarray__"):
343+
usm_arr = getattr(seq_o, "__usm_ndarray__")
344+
if isinstance(usm_arr, dpt.usm_ndarray):
345+
_device_copy_walker(usm_arr, res, _manager)
346+
return
333347
if hasattr(seq_o, "__sycl_usm_array_interface__"):
334348
usm_ar = _usm_ndarray_from_suai(seq_o)
335349
exec_q = res.sycl_queue
@@ -361,6 +375,11 @@ def _copy_through_host_walker(seq_o, usm_res):
361375
return
362376
else:
363377
usm_res[...] = seq_o
378+
if hasattr(seq_o, "__usm_ndarray__"):
379+
usm_arr = getattr(seq_o, "__usm_ndarray__")
380+
if isinstance(usm_arr, dpt.usm_ndarray):
381+
_copy_through_host_walker(usm_arr, usm_res)
382+
return
364383
if hasattr(seq_o, "__sycl_usm_array_interface__"):
365384
usm_ar = _usm_ndarray_from_suai(seq_o)
366385
if (
@@ -564,6 +583,17 @@ def asarray(
564583
sycl_queue=sycl_queue,
565584
order=order,
566585
)
586+
if hasattr(obj, "__usm_ndarray__"):
587+
usm_arr = getattr(obj, "__usm_ndarray__")
588+
if isinstance(usm_arr, dpt.usm_ndarray):
589+
return _asarray_from_usm_ndarray(
590+
usm_arr,
591+
dtype=dtype,
592+
copy=copy,
593+
usm_type=usm_type,
594+
sycl_queue=sycl_queue,
595+
order=order,
596+
)
567597
if hasattr(obj, "__sycl_usm_array_interface__"):
568598
ary = _usm_ndarray_from_suai(obj)
569599
return _asarray_from_usm_ndarray(

0 commit comments

Comments
 (0)