Skip to content

Commit 07013c3

Browse files
committed
Removes dead branch from _accumulators.py
A second out temporary does not need to be made in either branch when input and requested dtype are not implemented, as temporaries are always made Also removes part of a test intended to reach this branch
1 parent 53eabe6 commit 07013c3

File tree

2 files changed

+5
-18
lines changed

2 files changed

+5
-18
lines changed

dpctl/tensor/_accumulation.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,6 @@ def _accumulate_common(
189189
out = orig_out
190190
else:
191191
if _dtype_supported(res_dt, res_dt):
192-
# no need to check for orig_out here, branch should never
193-
# be reached if inp and out are the same array
194192
tmp = dpt.empty(
195193
arr.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
196194
)
@@ -229,31 +227,25 @@ def _accumulate_common(
229227
host_tasks_list.append(ht_e_cpy)
230228
if not include_initial:
231229
ht_e, a_e = _accumulate_fn(
232-
src=arr,
230+
src=tmp,
233231
trailing_dims_to_accumulate=1,
234232
dst=tmp_res,
235233
sycl_queue=q,
236234
depends=[cpy_e],
237235
)
238236
else:
239237
ht_e, a_e = _accumulate_include_initial_fn(
240-
src=arr,
238+
src=tmp,
241239
dst=tmp_res,
242240
sycl_queue=q,
243241
depends=[cpy_e],
244242
)
245243
host_tasks_list.append(ht_e)
246-
ht_e_cpy2, cpy_e2 = ti._copy_usm_ndarray_into_usm_ndarray(
244+
ht_e_cpy2, _ = ti._copy_usm_ndarray_into_usm_ndarray(
247245
src=tmp_res, dst=out, sycl_queue=q, depends=[a_e]
248246
)
249247
host_tasks_list.append(ht_e_cpy2)
250-
if not (orig_out is None or out is orig_out):
251-
# Copy the out data from temporary buffer to original memory
252-
ht_e_cpy3, _ = ti._copy_usm_ndarray_into_usm_ndarray(
253-
src=out, dst=orig_out, sycl_queue=q, depends=[cpy_e2]
254-
)
255-
host_tasks_list.append(ht_e_cpy3)
256-
out = orig_out
248+
257249
if appended_axis:
258250
out = dpt.squeeze(out)
259251
if a1 != nd:

dpctl/tests/test_tensor_accumulation.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def test_accumulator_out_kwarg():
231231
out = dpt.empty_like(x, dtype=default_int)
232232
dpt.cumulative_sum(x, out=out)
233233
assert dpt.all(expected == out)
234+
234235
# overlap
235236
x = dpt.ones(n, dtype=default_int, sycl_queue=q)
236237
dpt.cumulative_sum(x, out=x)
@@ -252,12 +253,6 @@ def test_accumulator_out_kwarg():
252253
dpt.cumulative_sum(x, out=out)
253254
assert expected == out
254255

255-
# overlapping and unimplemented
256-
x = dpt.ones(n, dtype="?", sycl_queue=q)
257-
x[20:] = False
258-
dpt.cumulative_sum(x, dtype="?", out=x)
259-
assert dpt.all(x)
260-
261256

262257
def test_accumulator_arg_validation():
263258
q1 = get_queue_or_skip()

0 commit comments

Comments
 (0)