diff --git a/dpctl/tensor/_set_functions.py b/dpctl/tensor/_set_functions.py index 2e2df751a9..bbba301da4 100644 --- a/dpctl/tensor/_set_functions.py +++ b/dpctl/tensor/_set_functions.py @@ -425,8 +425,7 @@ def unique_inverse(x): ) _manager.add_event_pair(ht_ev, sub_ev) - inv_dt = dpt.int64 if x.size > dpt.iinfo(dpt.int32).max else dpt.int32 - inv = dpt.empty_like(x, dtype=inv_dt, order="C") + inv = dpt.empty_like(x, dtype=ind_dt, order="C") ht_ev, ssl_ev = _searchsorted_left( hay=unique_vals, needles=x, @@ -608,8 +607,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult: ) _manager.add_event_pair(ht_ev, sub_ev) - inv_dt = dpt.int64 if x.size > dpt.iinfo(dpt.int32).max else dpt.int32 - inv = dpt.empty_like(x, dtype=inv_dt, order="C") + inv = dpt.empty_like(x, dtype=ind_dt, order="C") ht_ev, ssl_ev = _searchsorted_left( hay=unique_vals, needles=x, diff --git a/dpctl/tests/test_usm_ndarray_unique.py b/dpctl/tests/test_usm_ndarray_unique.py index f3504ee032..fcd55fdfc1 100644 --- a/dpctl/tests/test_usm_ndarray_unique.py +++ b/dpctl/tests/test_usm_ndarray_unique.py @@ -321,3 +321,25 @@ def test_set_functions_compute_follows_data(): assert ind.sycl_queue == q assert inv_ind.sycl_queue == q assert uc.sycl_queue == q + + +def test_gh_1738(): + get_queue_or_skip() + + ones = dpt.ones(10, dtype="i8") + iota = dpt.arange(10, dtype="i8") + + assert ones.device == iota.device + + dpt_info = dpt.__array_namespace_info__() + ind_dt = dpt_info.default_dtypes(device=ones.device)["indexing"] + + dt = dpt.unique_inverse(ones).inverse_indices.dtype + assert dt == ind_dt + dt = dpt.unique_all(ones).inverse_indices.dtype + assert dt == ind_dt + + dt = dpt.unique_inverse(iota).inverse_indices.dtype + assert dt == ind_dt + dt = dpt.unique_all(iota).inverse_indices.dtype + assert dt == ind_dt