diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 541d1d4fae..e759571790 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -517,9 +517,10 @@ def _nonzero_impl(ary): mask_nelems, dtype=cumsum_dt, sycl_queue=exec_q, order="C" ) mask_count = ti.mask_positions(ary, cumsum, sycl_queue=exec_q) + indexes_dt = ti.default_device_int_type(exec_q.sycl_device) indexes = dpt.empty( (ary.ndim, mask_count), - dtype=cumsum.dtype, + dtype=indexes_dt, usm_type=usm_type, sycl_queue=exec_q, order="C", diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 41688075e0..87d89a1b8d 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1345,3 +1345,15 @@ def test_nonzero_arg_validation(): dpt.nonzero(list()) with pytest.raises(ValueError): dpt.nonzero(dpt.asarray(1)) + + +def test_nonzero_dtype(): + "See gh-1322" + get_queue_or_skip() + x = dpt.ones((3, 4)) + idx, idy = dpt.nonzero(x) + # create array using device's + # default integral data type + ref = dpt.arange(8) + assert idx.dtype == ref.dtype + assert idy.dtype == ref.dtype