Skip to content

Commit 9403f76

Browse files
Merge pull request #1756 from IntelPython/propagate-ro-flag-for-suai-in-asarray
Propagate read-only flag from sycl_usm_array_interface in asarray
2 parents f83f95b + 9a5715f commit 9403f76

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

dpctl/tensor/_ctors.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,13 @@ def _usm_ndarray_from_suai(obj):
202202
buffer=membuf,
203203
strides=sua_iface.get("strides", None),
204204
)
205+
_data_field = sua_iface["data"]
206+
if isinstance(_data_field, tuple) and len(_data_field) > 1:
207+
ro_field = _data_field[1]
208+
else:
209+
ro_field = False
210+
if ro_field:
211+
ary.flags["W"] = False
205212
return ary
206213

207214

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2368,3 +2368,29 @@ def test_gh_1201():
23682368
c = dpt.flip(dpt.empty(a.shape, dtype=a.dtype))
23692369
c[:] = a
23702370
assert (dpt.asnumpy(c) == a).all()
2371+
2372+
2373+
class ObjWithSyclUsmArrayInterface:
2374+
def __init__(self, ary):
2375+
self._array_obj = ary
2376+
2377+
@property
2378+
def __sycl_usm_array_interface__(self):
2379+
_suai = self._array_obj.__sycl_usm_array_interface__
2380+
return _suai
2381+
2382+
2383+
@pytest.mark.parametrize("ro_flag", [True, False])
2384+
def test_asarray_writable_flag(ro_flag):
2385+
try:
2386+
a = dpt.empty(8)
2387+
except dpctl.SyclDeviceCreationError:
2388+
pytest.skip("No SYCL devices available")
2389+
2390+
a.flags["W"] = not ro_flag
2391+
wrapped = ObjWithSyclUsmArrayInterface(a)
2392+
2393+
b = dpt.asarray(wrapped)
2394+
2395+
assert b.flags["W"] == (not ro_flag)
2396+
assert b._pointer == a._pointer

0 commit comments

Comments
 (0)