@@ -142,21 +142,24 @@ def _reduction_over_axis(
142
142
"Automatically determined reduction data type does not "
143
143
"have direct implementation"
144
144
)
145
- tmp_dt = _default_reduction_type_fn (inp_dt , q )
146
145
tmp = dpt .empty (
147
- res_shape , dtype = tmp_dt , usm_type = res_usm_type , sycl_queue = q
146
+ arr2 . shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
148
147
)
149
- ht_e_tmp , r_e = _reduction_fn (
150
- src = arr2 , trailing_dims_to_reduce = red_nd , dst = tmp , sycl_queue = q
148
+ ht_e_cpy , cpy_e = ti . _copy_usm_ndarray_into_usm_ndarray (
149
+ src = arr2 , dst = tmp , sycl_queue = q
151
150
)
152
- host_tasks_list .append (ht_e_tmp )
151
+ host_tasks_list .append (ht_e_cpy )
153
152
res = dpt .empty (
154
153
res_shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
155
154
)
156
- ht_e , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
157
- src = tmp , dst = res , sycl_queue = q , depends = [r_e ]
155
+ ht_e_red , r_e = _reduction_fn (
156
+ src = tmp ,
157
+ trailing_dims_to_reduce = red_nd ,
158
+ dst = res ,
159
+ sycl_queue = q ,
160
+ depends = [cpy_e ],
158
161
)
159
- host_tasks_list .append (ht_e )
162
+ host_tasks_list .append (ht_e_red )
160
163
161
164
if keepdims :
162
165
res_shape = res_shape + (1 ,) * red_nd
0 commit comments