Skip to content

Commit 486e06d

Browse files
Fixed gh-1468
Function _reduce_over_axis promotes input array to requested result data type and carries out reduction computation in that data type.
1 parent 07c075b commit 486e06d

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

dpctl/tensor/_reduction.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,21 +142,24 @@ def _reduction_over_axis(
142142
"Automatically determined reduction data type does not "
143143
"have direct implementation"
144144
)
145-
tmp_dt = _default_reduction_type_fn(inp_dt, q)
146145
tmp = dpt.empty(
147-
res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q
146+
arr2.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
148147
)
149-
ht_e_tmp, r_e = _reduction_fn(
150-
src=arr2, trailing_dims_to_reduce=red_nd, dst=tmp, sycl_queue=q
148+
ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
149+
src=arr2, dst=tmp, sycl_queue=q
151150
)
152-
host_tasks_list.append(ht_e_tmp)
151+
host_tasks_list.append(ht_e_cpy)
153152
res = dpt.empty(
154153
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
155154
)
156-
ht_e, _ = ti._copy_usm_ndarray_into_usm_ndarray(
157-
src=tmp, dst=res, sycl_queue=q, depends=[r_e]
155+
ht_e_red, r_e = _reduction_fn(
156+
src=tmp,
157+
trailing_dims_to_reduce=red_nd,
158+
dst=res,
159+
sycl_queue=q,
160+
depends=[cpy_e],
158161
)
159-
host_tasks_list.append(ht_e)
162+
host_tasks_list.append(ht_e_red)
160163

161164
if keepdims:
162165
res_shape = res_shape + (1,) * red_nd

0 commit comments

Comments
 (0)