Skip to content

Commit f69dedf

Browse files
committed
Removed redundant branches in clip and elementwise function calls
As the result dtype of the out array is already checked when overlap is checked, checking again later is superfluous
1 parent 7f369b0 commit f69dedf

File tree

2 files changed

+3
-18
lines changed

2 files changed

+3
-18
lines changed

dpctl/tensor/_clip.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -374,12 +374,7 @@ def _clip_none_call(x, val, out, order, _binary_fn):
374374
sycl_queue=exec_q,
375375
order=order,
376376
)
377-
else:
378-
if res_dt != out.dtype:
379-
raise TypeError(
380-
f"Output array of type {res_dt} is needed,"
381-
f"got {out.dtype}"
382-
)
377+
383378
if x_shape != res_shape:
384379
x = dpt.broadcast_to(x, res_shape)
385380
buf = dpt.broadcast_to(buf, res_shape)
@@ -696,12 +691,7 @@ def clip(x, min=None, max=None, out=None, order="K"):
696691
sycl_queue=exec_q,
697692
order=order,
698693
)
699-
else:
700-
if res_dt != out.dtype:
701-
raise TypeError(
702-
f"Output array of type {res_dt} is needed, "
703-
f"got {out.dtype}"
704-
)
694+
705695
x = dpt.broadcast_to(x, res_shape)
706696
if a_min.shape != res_shape:
707697
a_min = dpt.broadcast_to(a_min, res_shape)

dpctl/tensor/_elementwise_common.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -649,12 +649,7 @@ def __call__(self, o1, o2, out=None, order="K"):
649649
sycl_queue=exec_q,
650650
order=order,
651651
)
652-
else:
653-
if res_dt != out.dtype:
654-
raise TypeError(
655-
f"Output array of type {res_dt} is needed,"
656-
f"got {out.dtype}"
657-
)
652+
658653
if src1.shape != res_shape:
659654
src1 = dpt.broadcast_to(src1, res_shape)
660655
buf2 = dpt.broadcast_to(buf2, res_shape)

0 commit comments

Comments
 (0)