Skip to content

Commit 92367ad

Browse files
Deploy _divide_by_scalar in statistical functions
This brings significant performance improvement for Lloyd algorithm benchmark.
1 parent 36ebb57 commit 92367ad

File tree

1 file changed

+17
-27
lines changed

1 file changed

+17
-27
lines changed

dpctl/tensor/_statistical_functions.py

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,13 @@ def _var_impl(x, axis, correction, keepdims):
9393
)
9494
# divide in-place to get mean
9595
mean_ary_shape = mean_ary.shape
96-
nelems_ary = dpt.asarray(
97-
nelems, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
98-
)
99-
if nelems_ary.shape != mean_ary_shape:
100-
nelems_ary = dpt.broadcast_to(nelems_ary, mean_ary_shape)
96+
10197
dep_evs = _manager.submitted_events
102-
ht_e2, d_e1 = tei._divide_inplace(
103-
lhs=mean_ary, rhs=nelems_ary, sycl_queue=q, depends=dep_evs
98+
ht_e2, d_e1 = tei._divide_by_scalar(
99+
src=mean_ary, scalar=nelems, dst=mean_ary, sycl_queue=q, depends=dep_evs
104100
)
105101
_manager.add_event_pair(ht_e2, d_e1)
102+
106103
# subtract mean from original array to get deviations
107104
dev_ary = dpt.empty_like(buf)
108105
if mean_ary_shape != buf.shape:
@@ -144,17 +141,18 @@ def _var_impl(x, axis, correction, keepdims):
144141
res_shape = res.shape
145142
# when nelems - correction <= 0, yield nans
146143
div = max(nelems - correction, 0)
147-
if not div:
148-
div = dpt.nan
149-
div_ary = dpt.asarray(
150-
div, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
151-
)
152-
# divide in-place again
153-
if div_ary.shape != res_shape:
154-
div_ary = dpt.broadcast_to(div_ary, res.shape)
144+
if div:
145+
dep_evs = _manager.submitted_events
146+
ht_e7, d_e2 = tei._divide_by_scalar(
147+
src=res, scalar=div, dst=res, sycl_queue=q, depends=dep_evs
148+
)
149+
_manager.add_event_pair(ht_e7, d_e2)
150+
return res, [d_e2]
151+
152+
div = dpt.nan
155153
dep_evs = _manager.submitted_events
156-
ht_e7, d_e2 = tei._divide_inplace(
157-
lhs=res, rhs=div_ary, sycl_queue=q, depends=dep_evs
154+
ht_e7, d_e2 = tei._divide_by_scalar(
155+
src=res, scalar=div, dst=res, sycl_queue=q, depends=dep_evs
158156
)
159157
_manager.add_event_pair(ht_e7, d_e2)
160158
return res, [d_e2]
@@ -259,17 +257,9 @@ def mean(x, axis=None, keepdims=False):
259257
inv_perm = sorted(range(nd), key=lambda d: perm[d])
260258
res = dpt.permute_dims(dpt.reshape(res, res_shape), inv_perm)
261259

262-
res_shape = res.shape
263-
# in-place divide
264-
den_dt = dpt.finfo(res_dt).dtype if res_dt.kind == "c" else res_dt
265-
nelems_arr = dpt.asarray(
266-
nelems, dtype=den_dt, usm_type=res_usm_type, sycl_queue=q
267-
)
268-
if nelems_arr.shape != res_shape:
269-
nelems_arr = dpt.broadcast_to(nelems_arr, res_shape)
270260
dep_evs = _manager.submitted_events
271-
ht_e2, div_e = tei._divide_inplace(
272-
lhs=res, rhs=nelems_arr, sycl_queue=q, depends=dep_evs
261+
ht_e2, div_e = tei._divide_by_scalar(
262+
src=res, scalar=nelems, dst=res, sycl_queue=q, depends=dep_evs
273263
)
274264
_manager.add_event_pair(ht_e2, div_e)
275265
return res

0 commit comments

Comments
 (0)