Skip to content

Commit 28cc7bd

Browse files
Merge pull request #1741 from IntelPython/fix-for-gh-1738
Fix for gh 1738
2 parents e35dfa8 + 7e2c43c commit 28cc7bd

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

dpctl/tensor/_set_functions.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,7 @@ def unique_inverse(x):
425425
)
426426
_manager.add_event_pair(ht_ev, sub_ev)
427427

428-
inv_dt = dpt.int64 if x.size > dpt.iinfo(dpt.int32).max else dpt.int32
429-
inv = dpt.empty_like(x, dtype=inv_dt, order="C")
428+
inv = dpt.empty_like(x, dtype=ind_dt, order="C")
430429
ht_ev, ssl_ev = _searchsorted_left(
431430
hay=unique_vals,
432431
needles=x,
@@ -608,8 +607,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
608607
)
609608
_manager.add_event_pair(ht_ev, sub_ev)
610609

611-
inv_dt = dpt.int64 if x.size > dpt.iinfo(dpt.int32).max else dpt.int32
612-
inv = dpt.empty_like(x, dtype=inv_dt, order="C")
610+
inv = dpt.empty_like(x, dtype=ind_dt, order="C")
613611
ht_ev, ssl_ev = _searchsorted_left(
614612
hay=unique_vals,
615613
needles=x,

dpctl/tests/test_usm_ndarray_unique.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,25 @@ def test_set_functions_compute_follows_data():
321321
assert ind.sycl_queue == q
322322
assert inv_ind.sycl_queue == q
323323
assert uc.sycl_queue == q
324+
325+
326+
def test_gh_1738():
327+
get_queue_or_skip()
328+
329+
ones = dpt.ones(10, dtype="i8")
330+
iota = dpt.arange(10, dtype="i8")
331+
332+
assert ones.device == iota.device
333+
334+
dpt_info = dpt.__array_namespace_info__()
335+
ind_dt = dpt_info.default_dtypes(device=ones.device)["indexing"]
336+
337+
dt = dpt.unique_inverse(ones).inverse_indices.dtype
338+
assert dt == ind_dt
339+
dt = dpt.unique_all(ones).inverse_indices.dtype
340+
assert dt == ind_dt
341+
342+
dt = dpt.unique_inverse(iota).inverse_indices.dtype
343+
assert dt == ind_dt
344+
dt = dpt.unique_all(iota).inverse_indices.dtype
345+
assert dt == ind_dt

0 commit comments

Comments
 (0)