From 67fbdc31c8bfce73337f35094868636c1f0416d2 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 21 Feb 2025 13:28:19 -0800 Subject: [PATCH 1/2] Clean up code in _accumulation.py Renamed events Remove unnecessary instantiation of a SyclEvent --- dpctl/tensor/_accumulation.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/dpctl/tensor/_accumulation.py b/dpctl/tensor/_accumulation.py index 4647073203..46b99e512d 100644 --- a/dpctl/tensor/_accumulation.py +++ b/dpctl/tensor/_accumulation.py @@ -125,7 +125,6 @@ def _accumulate_common( if a1 != nd: out = dpt.permute_dims(out, perm) - final_ev = dpctl.SyclEvent() _manager = SequentialOrderManager[q] depends = _manager.submitted_events if implemented_types: @@ -144,12 +143,11 @@ def _accumulate_common( _manager.add_event_pair(ht_e, acc_ev) if not (orig_out is None or out is orig_out): # Copy the out data from temporary buffer to original memory - ht_e_cpy, acc_ev = ti._copy_usm_ndarray_into_usm_ndarray( + ht_e_cpy, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray( src=out, dst=orig_out, sycl_queue=q, depends=[acc_ev] ) - _manager.add_event_pair(ht_e_cpy, acc_ev) + _manager.add_event_pair(ht_e_cpy, cpy_ev) out = orig_out - final_ev = acc_ev else: if _dtype_supported(res_dt, res_dt): tmp = dpt.empty( @@ -160,7 +158,7 @@ def _accumulate_common( ) _manager.add_event_pair(ht_e_cpy, cpy_e) if not include_initial: - ht_e, final_ev = _accumulate_fn( + ht_e, acc_ev = _accumulate_fn( src=tmp, trailing_dims_to_accumulate=1, dst=out, @@ -168,13 +166,13 @@ def _accumulate_common( depends=[cpy_e], ) else: - ht_e, final_ev = _accumulate_include_initial_fn( + ht_e, acc_ev = _accumulate_include_initial_fn( src=tmp, dst=out, sycl_queue=q, depends=[cpy_e], ) - _manager.add_event_pair(ht_e, final_ev) + _manager.add_event_pair(ht_e, acc_ev) else: buf_dt = _default_accumulation_type_fn(inp_dt, q) tmp = dpt.empty( @@ -190,7 +188,7 @@ def _accumulate_common( if a1 != nd: tmp_res = dpt.permute_dims(tmp_res, perm) if not include_initial: - ht_e, a_e = _accumulate_fn( + ht_e, acc_ev = _accumulate_fn( src=tmp, trailing_dims_to_accumulate=1, dst=tmp_res, @@ -198,17 +196,17 @@ def _accumulate_common( depends=[cpy_e], ) else: - ht_e, a_e = _accumulate_include_initial_fn( + ht_e, acc_ev = _accumulate_include_initial_fn( src=tmp, dst=tmp_res, sycl_queue=q, depends=[cpy_e], ) - _manager.add_event_pair(ht_e, a_e) - ht_e_cpy2, final_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=tmp_res, dst=out, sycl_queue=q, depends=[a_e] + _manager.add_event_pair(ht_e, acc_ev) + ht_e_cpy2, cpy_ev2 = ti._copy_usm_ndarray_into_usm_ndarray( + src=tmp_res, dst=out, sycl_queue=q, depends=[acc_ev] ) - _manager.add_event_pair(ht_e_cpy2, final_ev) + _manager.add_event_pair(ht_e_cpy2, cpy_ev2) if appended_axis: out = dpt.squeeze(out) From cd2bd1a878fe06e88b2484e05578582d7ecb6aab Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 25 Feb 2025 19:39:30 -0800 Subject: [PATCH 2/2] Rename variables in _accumulate_common --- dpctl/tensor/_accumulation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dpctl/tensor/_accumulation.py b/dpctl/tensor/_accumulation.py index 46b99e512d..1006d222b9 100644 --- a/dpctl/tensor/_accumulation.py +++ b/dpctl/tensor/_accumulation.py @@ -143,10 +143,10 @@ def _accumulate_common( _manager.add_event_pair(ht_e, acc_ev) if not (orig_out is None or out is orig_out): # Copy the out data from temporary buffer to original memory - ht_e_cpy, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( src=out, dst=orig_out, sycl_queue=q, depends=[acc_ev] ) - _manager.add_event_pair(ht_e_cpy, cpy_ev) + _manager.add_event_pair(ht_e_cpy, cpy_e) out = orig_out else: if _dtype_supported(res_dt, res_dt): @@ -203,10 +203,10 @@ def _accumulate_common( depends=[cpy_e], ) _manager.add_event_pair(ht_e, acc_ev) - ht_e_cpy2, cpy_ev2 = ti._copy_usm_ndarray_into_usm_ndarray( + ht_e_cpy2, cpy_e2 = ti._copy_usm_ndarray_into_usm_ndarray( src=tmp_res, dst=out, sycl_queue=q, depends=[acc_ev] ) - _manager.add_event_pair(ht_e_cpy2, cpy_ev2) + _manager.add_event_pair(ht_e_cpy2, cpy_e2) if appended_axis: out = dpt.squeeze(out)