Skip to content

Commit 64b83bb

Browse files
committed
Removes a dead branch from _accumulate_common
As `out` and the input would have to have the same data type to overlap, the second branch is never reached if `out` is the same array as the input
1 parent 441e081 commit 64b83bb

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

dpctl/tensor/_accumulation.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ def _accumulate_common(
189189
out = orig_out
190190
else:
191191
if _dtype_supported(res_dt, res_dt):
192+
# no need to check for orig_out here, branch should never
193+
# be reached if inp and out are the same array
192194
tmp = dpt.empty(
193195
arr.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
194196
)
@@ -211,14 +213,6 @@ def _accumulate_common(
211213
sycl_queue=q,
212214
depends=[cpy_e],
213215
)
214-
host_tasks_list.append(ht_e)
215-
if not (orig_out is None or out is orig_out):
216-
# Copy the out data from temporary buffer to original memory
217-
ht_e_cpy2, _ = ti._copy_usm_ndarray_into_usm_ndarray(
218-
src=out, dst=orig_out, sycl_queue=q, depends=[acc_ev]
219-
)
220-
host_tasks_list.append(ht_e_cpy2)
221-
out = orig_out
222216
else:
223217
buf_dt = _default_accumulation_type_fn(inp_dt, q)
224218
tmp = dpt.empty(

0 commit comments

Comments
 (0)