22
22
import dpctl .tensor ._tensor_impl as ti
23
23
import dpctl .tensor ._tensor_reductions_impl as tri
24
24
25
- from ._reduction import _default_reduction_dtype
26
-
27
25
28
26
def _var_impl (x , axis , correction , keepdims ):
29
27
nd = x .ndim
@@ -233,22 +231,25 @@ def mean(x, axis=None, keepdims=False):
233
231
host_tasks_list .append (ht_e1 )
234
232
s_e .append (r_e )
235
233
else :
236
- tmp_dt = _default_reduction_dtype (inp_dt , q )
237
234
tmp = dpt .empty (
238
- res_shape , dtype = tmp_dt , usm_type = res_usm_type , sycl_queue = q
235
+ arr2 . shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
239
236
)
240
- ht_e_tmp , r_e = tri . _sum_over_axis (
241
- src = arr2 , trailing_dims_to_reduce = sum_nd , dst = tmp , sycl_queue = q
237
+ ht_e_cpy , cpy_e = ti . _copy_usm_ndarray_into_usm_ndarray (
238
+ src = arr2 , dst = tmp , sycl_queue = q
242
239
)
243
- host_tasks_list .append (ht_e_tmp )
240
+ host_tasks_list .append (ht_e_cpy )
244
241
res = dpt .empty (
245
242
res_shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
246
243
)
247
- ht_e1 , c_e = ti ._copy_usm_ndarray_into_usm_ndarray (
248
- src = tmp , dst = res , sycl_queue = q , depends = [r_e ]
244
+ ht_e_red , r_e = tri ._sum_over_axis (
245
+ src = tmp ,
246
+ trailing_dims_to_reduce = sum_nd ,
247
+ dst = res ,
248
+ sycl_queue = q ,
249
+ depends = [cpy_e ],
249
250
)
250
- host_tasks_list .append (ht_e1 )
251
- s_e .append (c_e )
251
+ host_tasks_list .append (ht_e_red )
252
+ s_e .append (r_e )
252
253
253
254
if keepdims :
254
255
res_shape = res_shape + (1 ,) * sum_nd
@@ -257,8 +258,9 @@ def mean(x, axis=None, keepdims=False):
257
258
258
259
res_shape = res .shape
259
260
# in-place divide
261
+ den_dt = dpt .finfo (res_dt ).dtype if res_dt .kind == "c" else res_dt
260
262
nelems_arr = dpt .asarray (
261
- nelems , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
263
+ nelems , dtype = den_dt , usm_type = res_usm_type , sycl_queue = q
262
264
)
263
265
if nelems_arr .shape != res_shape :
264
266
nelems_arr = dpt .broadcast_to (nelems_arr , res_shape )
0 commit comments