Skip to content

Commit bcef07a

Browse files
Merge pull request #1969 from IntelPython/stream-keyword-validation
Add stream argument validation
2 parents a86da4b + 74066bb commit bcef07a

File tree

2 files changed

+49
-37
lines changed

2 files changed

+49
-37
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,19 @@ cdef bint _is_host_cpu(object dl_device):
149149
return (dl_type == DLDeviceType.kDLCPU) and (dl_id == 0)
150150

151151

152+
cdef void _validate_and_use_stream(object stream, c_dpctl.SyclQueue self_queue) except *:
153+
if (stream is None or stream == self_queue):
154+
pass
155+
else:
156+
if not isinstance(stream, dpctl.SyclQueue):
157+
raise TypeError(
158+
"stream argument type was expected to be dpctl.SyclQueue,"
159+
f" got {type(stream)} instead"
160+
)
161+
ev = self_queue.submit_barrier()
162+
stream.submit_barrier(dependent_events=[ev])
163+
164+
152165
cdef class usm_ndarray:
153166
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
154167
offset=0, order="C", buffer_ctor_kwargs=dict(), \
@@ -1025,12 +1038,7 @@ cdef class usm_ndarray:
10251038
cdef c_dpmem._Memory arr_buf
10261039
d = Device.create_device(target_device)
10271040

1028-
if (stream is None or not isinstance(stream, dpctl.SyclQueue) or
1029-
stream == self.sycl_queue):
1030-
pass
1031-
else:
1032-
ev = self.sycl_queue.submit_barrier()
1033-
stream.submit_barrier(dependent_events=[ev])
1041+
_validate_and_use_stream(stream, self.sycl_queue)
10341042

10351043
if (d.sycl_context == self.sycl_context):
10361044
arr_buf = <c_dpmem._Memory> self.usm_data
@@ -1203,12 +1211,7 @@ cdef class usm_ndarray:
12031211
# legacy path for DLManagedTensor
12041212
# copy kwarg ignored because copy flag can't be set
12051213
_caps = c_dlpack.to_dlpack_capsule(self)
1206-
if (stream is None or type(stream) is not dpctl.SyclQueue or
1207-
stream == self.sycl_queue):
1208-
pass
1209-
else:
1210-
ev = self.sycl_queue.submit_barrier()
1211-
stream.submit_barrier(dependent_events=[ev])
1214+
_validate_and_use_stream(stream, self.sycl_queue)
12121215
return _caps
12131216
else:
12141217
if not isinstance(max_version, tuple) or len(max_version) != 2:
@@ -1250,12 +1253,7 @@ cdef class usm_ndarray:
12501253
copy = False
12511254
# TODO: strategy for handling stream on different device from dl_device
12521255
if copy:
1253-
if (stream is None or type(stream) is not dpctl.SyclQueue or
1254-
stream == self.sycl_queue):
1255-
pass
1256-
else:
1257-
ev = self.sycl_queue.submit_barrier()
1258-
stream.submit_barrier(dependent_events=[ev])
1256+
_validate_and_use_stream(stream, self.sycl_queue)
12591257
nbytes = self.usm_data.nbytes
12601258
copy_buffer = type(self.usm_data)(
12611259
nbytes, queue=self.sycl_queue
@@ -1272,22 +1270,12 @@ cdef class usm_ndarray:
12721270
_caps = c_dlpack.to_dlpack_versioned_capsule(_copied_arr, copy)
12731271
else:
12741272
_caps = c_dlpack.to_dlpack_versioned_capsule(self, copy)
1275-
if (stream is None or type(stream) is not dpctl.SyclQueue or
1276-
stream == self.sycl_queue):
1277-
pass
1278-
else:
1279-
ev = self.sycl_queue.submit_barrier()
1280-
stream.submit_barrier(dependent_events=[ev])
1273+
_validate_and_use_stream(stream, self.sycl_queue)
12811274
return _caps
12821275
else:
12831276
# legacy path for DLManagedTensor
12841277
_caps = c_dlpack.to_dlpack_capsule(self)
1285-
if (stream is None or type(stream) is not dpctl.SyclQueue or
1286-
stream == self.sycl_queue):
1287-
pass
1288-
else:
1289-
ev = self.sycl_queue.submit_barrier()
1290-
stream.submit_barrier(dependent_events=[ev])
1278+
_validate_and_use_stream(stream, self.sycl_queue)
12911279
return _caps
12921280

12931281
def __dlpack_device__(self):
@@ -1555,17 +1543,17 @@ cdef class usm_ndarray:
15551543
def __array__(self, dtype=None, /, *, copy=None):
15561544
"""NumPy's array protocol method to disallow implicit conversion.
15571545
1558-
Without this definition, `numpy.asarray(usm_ar)` converts
1559-
usm_ndarray instance into NumPy array with data type `object`
1560-
and every element being 0d usm_ndarray.
1546+
Without this definition, `numpy.asarray(usm_ar)` converts
1547+
usm_ndarray instance into NumPy array with data type `object`
1548+
and every element being 0d usm_ndarray.
15611549
15621550
https://github.com/IntelPython/dpctl/pull/1384#issuecomment-1707212972
1563-
"""
1551+
"""
15641552
raise TypeError(
15651553
"Implicit conversion to a NumPy array is not allowed. "
1566-
"Use `dpctl.tensor.asnumpy` to copy data from this "
1567-
"`dpctl.tensor.usm_ndarray` instance to NumPy array"
1568-
)
1554+
"Use `dpctl.tensor.asnumpy` to copy data from this "
1555+
"`dpctl.tensor.usm_ndarray` instance to NumPy array"
1556+
)
15691557

15701558

15711559
cdef usm_ndarray _real_view(usm_ndarray ary):

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1380,6 +1380,30 @@ def test_to_device():
13801380
assert Y.sycl_device == dev
13811381

13821382

1383+
def test_to_device_stream_validation():
1384+
try:
1385+
X = dpt.usm_ndarray(1, "f4")
1386+
except dpctl.SyclDeviceCreationError:
1387+
pytest.skip("No SYCL devices available")
1388+
# invalid type of stream keyword
1389+
with pytest.raises(TypeError):
1390+
X.to_device(X.sycl_queue, stream=dict())
1391+
# stream is keyword-only arg
1392+
with pytest.raises(TypeError):
1393+
X.to_device(X.sycl_queue, X.sycl_queue)
1394+
1395+
1396+
def test_to_device_stream_use():
1397+
try:
1398+
X = dpt.usm_ndarray(1, "f4")
1399+
except dpctl.SyclDeviceCreationError:
1400+
pytest.skip("No SYCL devices available")
1401+
q1 = dpctl.SyclQueue(
1402+
X.sycl_context, X.sycl_device, property="enable_profiling"
1403+
)
1404+
X.to_device(q1, stream=q1)
1405+
1406+
13831407
def test_to_device_migration():
13841408
q1 = get_queue_or_skip() # two distinct copies of default-constructed queue
13851409
q2 = get_queue_or_skip()

0 commit comments

Comments
 (0)