@@ -93,16 +93,13 @@ def _var_impl(x, axis, correction, keepdims):
93
93
)
94
94
# divide in-place to get mean
95
95
mean_ary_shape = mean_ary .shape
96
- nelems_ary = dpt .asarray (
97
- nelems , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
98
- )
99
- if nelems_ary .shape != mean_ary_shape :
100
- nelems_ary = dpt .broadcast_to (nelems_ary , mean_ary_shape )
96
+
101
97
dep_evs = _manager .submitted_events
102
- ht_e2 , d_e1 = tei ._divide_inplace (
103
- lhs = mean_ary , rhs = nelems_ary , sycl_queue = q , depends = dep_evs
98
+ ht_e2 , d_e1 = tei ._divide_by_scalar (
99
+ src = mean_ary , scalar = nelems , dst = mean_ary , sycl_queue = q , depends = dep_evs
104
100
)
105
101
_manager .add_event_pair (ht_e2 , d_e1 )
102
+
106
103
# subtract mean from original array to get deviations
107
104
dev_ary = dpt .empty_like (buf )
108
105
if mean_ary_shape != buf .shape :
@@ -144,17 +141,18 @@ def _var_impl(x, axis, correction, keepdims):
144
141
res_shape = res .shape
145
142
# when nelems - correction <= 0, yield nans
146
143
div = max (nelems - correction , 0 )
147
- if not div :
148
- div = dpt .nan
149
- div_ary = dpt .asarray (
150
- div , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
151
- )
152
- # divide in-place again
153
- if div_ary .shape != res_shape :
154
- div_ary = dpt .broadcast_to (div_ary , res .shape )
144
+ if div :
145
+ dep_evs = _manager .submitted_events
146
+ ht_e7 , d_e2 = tei ._divide_by_scalar (
147
+ src = res , scalar = div , dst = res , sycl_queue = q , depends = dep_evs
148
+ )
149
+ _manager .add_event_pair (ht_e7 , d_e2 )
150
+ return res , [d_e2 ]
151
+
152
+ div = dpt .nan
155
153
dep_evs = _manager .submitted_events
156
- ht_e7 , d_e2 = tei ._divide_inplace (
157
- lhs = res , rhs = div_ary , sycl_queue = q , depends = dep_evs
154
+ ht_e7 , d_e2 = tei ._divide_by_scalar (
155
+ src = res , scalar = div , dst = res , sycl_queue = q , depends = dep_evs
158
156
)
159
157
_manager .add_event_pair (ht_e7 , d_e2 )
160
158
return res , [d_e2 ]
@@ -259,17 +257,9 @@ def mean(x, axis=None, keepdims=False):
259
257
inv_perm = sorted (range (nd ), key = lambda d : perm [d ])
260
258
res = dpt .permute_dims (dpt .reshape (res , res_shape ), inv_perm )
261
259
262
- res_shape = res .shape
263
- # in-place divide
264
- den_dt = dpt .finfo (res_dt ).dtype if res_dt .kind == "c" else res_dt
265
- nelems_arr = dpt .asarray (
266
- nelems , dtype = den_dt , usm_type = res_usm_type , sycl_queue = q
267
- )
268
- if nelems_arr .shape != res_shape :
269
- nelems_arr = dpt .broadcast_to (nelems_arr , res_shape )
270
260
dep_evs = _manager .submitted_events
271
- ht_e2 , div_e = tei ._divide_inplace (
272
- lhs = res , rhs = nelems_arr , sycl_queue = q , depends = dep_evs
261
+ ht_e2 , div_e = tei ._divide_by_scalar (
262
+ src = res , scalar = nelems , dst = res , sycl_queue = q , depends = dep_evs
273
263
)
274
264
_manager .add_event_pair (ht_e2 , div_e )
275
265
return res
0 commit comments