Skip to content

Commit 19ffc5f

Browse files
committed
Changes mean reduction to use output data type as sum accumulation type
Mean in-place division now uses the real type for the denominator
1 parent 69fdaa5 commit 19ffc5f

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

dpctl/tensor/_statistical_functions.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
import dpctl.tensor._tensor_impl as ti
2323
import dpctl.tensor._tensor_reductions_impl as tri
2424

25-
from ._reduction import _default_reduction_dtype
26-
2725

2826
def _var_impl(x, axis, correction, keepdims):
2927
nd = x.ndim
@@ -233,22 +231,25 @@ def mean(x, axis=None, keepdims=False):
233231
host_tasks_list.append(ht_e1)
234232
s_e.append(r_e)
235233
else:
236-
tmp_dt = _default_reduction_dtype(inp_dt, q)
237234
tmp = dpt.empty(
238-
res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q
235+
arr2.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
239236
)
240-
ht_e_tmp, r_e = tri._sum_over_axis(
241-
src=arr2, trailing_dims_to_reduce=sum_nd, dst=tmp, sycl_queue=q
237+
ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
238+
src=arr2, dst=tmp, sycl_queue=q
242239
)
243-
host_tasks_list.append(ht_e_tmp)
240+
host_tasks_list.append(ht_e_cpy)
244241
res = dpt.empty(
245242
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
246243
)
247-
ht_e1, c_e = ti._copy_usm_ndarray_into_usm_ndarray(
248-
src=tmp, dst=res, sycl_queue=q, depends=[r_e]
244+
ht_e_red, r_e = tri._sum_over_axis(
245+
src=tmp,
246+
trailing_dims_to_reduce=sum_nd,
247+
dst=res,
248+
sycl_queue=q,
249+
depends=[cpy_e],
249250
)
250-
host_tasks_list.append(ht_e1)
251-
s_e.append(c_e)
251+
host_tasks_list.append(ht_e_red)
252+
s_e.append(r_e)
252253

253254
if keepdims:
254255
res_shape = res_shape + (1,) * sum_nd
@@ -257,8 +258,9 @@ def mean(x, axis=None, keepdims=False):
257258

258259
res_shape = res.shape
259260
# in-place divide
261+
den_dt = dpt.finfo(res_dt).dtype if res_dt.kind == "c" else res_dt
260262
nelems_arr = dpt.asarray(
261-
nelems, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
263+
nelems, dtype=den_dt, usm_type=res_usm_type, sycl_queue=q
262264
)
263265
if nelems_arr.shape != res_shape:
264266
nelems_arr = dpt.broadcast_to(nelems_arr, res_shape)

0 commit comments

Comments
 (0)