Skip to content

Commit 130717f

Browse files
committed
Add test for invalid typenums in usm_ndarray-from-pointer functions
1 parent cc8c6d4 commit 130717f

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,75 @@ def test_pyx_capi_make_general():
963963
assert zd_arr._pointer == mat._pointer
964964

965965

966+
def test_pyx_capi_make_fns_invalid_typenum():
967+
q = get_queue_or_skip()
968+
usm_ndarray = dpt.empty(tuple(), dtype="i4", sycl_queue=q)
969+
970+
make_simple_from_ptr = _pyx_capi_fnptr_to_callable(
971+
usm_ndarray,
972+
"UsmNDArray_MakeSimpleFromPtr",
973+
b"PyObject *(size_t, int, DPCTLSyclUSMRef, "
974+
b"DPCTLSyclQueueRef, PyObject *)",
975+
fn_restype=ctypes.py_object,
976+
fn_argtypes=(
977+
ctypes.c_size_t,
978+
ctypes.c_int,
979+
ctypes.c_void_p,
980+
ctypes.c_void_p,
981+
ctypes.py_object,
982+
),
983+
)
984+
985+
nelems = 10
986+
dtype = dpt.int64
987+
arr = dpt.arange(nelems, dtype=dtype, sycl_queue=q)
988+
989+
with pytest.raises(ValueError):
990+
make_simple_from_ptr(
991+
ctypes.c_size_t(nelems),
992+
-1,
993+
arr._pointer,
994+
arr.sycl_queue.addressof_ref(),
995+
arr,
996+
)
997+
998+
make_from_ptr = _pyx_capi_fnptr_to_callable(
999+
usm_ndarray,
1000+
"UsmNDArray_MakeFromPtr",
1001+
b"PyObject *(int, Py_ssize_t const *, int, Py_ssize_t const *, "
1002+
b"DPCTLSyclUSMRef, DPCTLSyclQueueRef, Py_ssize_t, PyObject *)",
1003+
fn_restype=ctypes.py_object,
1004+
fn_argtypes=(
1005+
ctypes.c_int,
1006+
ctypes.POINTER(ctypes.c_ssize_t),
1007+
ctypes.c_int,
1008+
ctypes.POINTER(ctypes.c_ssize_t),
1009+
ctypes.c_void_p,
1010+
ctypes.c_void_p,
1011+
ctypes.c_ssize_t,
1012+
ctypes.py_object,
1013+
),
1014+
)
1015+
c_shape = (ctypes.c_ssize_t * 1)(
1016+
nelems,
1017+
)
1018+
c_strides = (ctypes.c_ssize_t * 1)(
1019+
1,
1020+
)
1021+
with pytest.raises(ValueError):
1022+
make_from_ptr(
1023+
ctypes.c_int(1),
1024+
c_shape,
1025+
-1,
1026+
c_strides,
1027+
arr._pointer,
1028+
arr.sycl_queue.addressof_ref(),
1029+
ctypes.c_ssize_t(0),
1030+
arr,
1031+
)
1032+
del arr
1033+
1034+
9661035
def _pyx_capi_int(X, pyx_capi_name, caps_name=b"int", val_restype=ctypes.c_int):
9671036
import sys
9681037

0 commit comments

Comments
 (0)