diff --git a/dpctl/tensor/_ctors.py b/dpctl/tensor/_ctors.py index fb400178f9..6c5f3888a6 100644 --- a/dpctl/tensor/_ctors.py +++ b/dpctl/tensor/_ctors.py @@ -202,6 +202,13 @@ def _usm_ndarray_from_suai(obj): buffer=membuf, strides=sua_iface.get("strides", None), ) + _data_field = sua_iface["data"] + if isinstance(_data_field, tuple) and len(_data_field) > 1: + ro_field = _data_field[1] + else: + ro_field = False + if ro_field: + ary.flags["W"] = False return ary diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index d9ce0eff50..6e702832e5 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -2368,3 +2368,29 @@ def test_gh_1201(): c = dpt.flip(dpt.empty(a.shape, dtype=a.dtype)) c[:] = a assert (dpt.asnumpy(c) == a).all() + + +class ObjWithSyclUsmArrayInterface: + def __init__(self, ary): + self._array_obj = ary + + @property + def __sycl_usm_array_interface__(self): + _suai = self._array_obj.__sycl_usm_array_interface__ + return _suai + + +@pytest.mark.parametrize("ro_flag", [True, False]) +def test_asarray_writable_flag(ro_flag): + try: + a = dpt.empty(8) + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + + a.flags["W"] = not ro_flag + wrapped = ObjWithSyclUsmArrayInterface(a) + + b = dpt.asarray(wrapped) + + assert b.flags["W"] == (not ro_flag) + assert b._pointer == a._pointer