Skip to content

Commit 9c5694b

Browse files
committed
Fixed memory-order selection logic in `clip`
The output order is now resolved by properly accounting for all three arrays (`x`, `a_min`, and `a_max`) in all branches
1 parent 9096243 commit 9c5694b

File tree

1 file changed

+37
-7
lines changed

1 file changed

+37
-7
lines changed

dpctl/tensor/_clip.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def _check_clip_dtypes(res_dtype, arg1_dtype, arg2_dtype, sycl_dev):
192192
return None, None, None
193193

194194

195-
def _clip_none_call(x, val, out, order, _binary_fn):
195+
def _clip_none(x, val, out, order, _binary_fn):
196196
if order not in ["K", "C", "F", "A"]:
197197
order = "K"
198198
q1, x_usm_type = x.sycl_queue, x.usm_type
@@ -429,9 +429,9 @@ def clip(x, min=None, max=None, out=None, order="K"):
429429
"only one of `min` and `max` is permitted to be `None`"
430430
)
431431
elif max is None:
432-
return _clip_none_call(x, min, out, order, ti._maximum)
432+
return _clip_none(x, min, out, order, ti._maximum)
433433
elif min is None:
434-
return _clip_none_call(x, max, out, order, ti._minimum)
434+
return _clip_none(x, max, out, order, ti._minimum)
435435
else:
436436
q1, x_usm_type = x.sycl_queue, x.usm_type
437437
q2, min_usm_type = _get_queue_usm_type(min)
@@ -646,12 +646,23 @@ def clip(x, min=None, max=None, out=None, order="K"):
646646
out = orig_out
647647
ht_binary_ev.wait()
648648
return out
649+
649650
elif buf1_dt is None:
650651
if order == "K":
651652
buf2 = _empty_like_orderK(a_max, buf2_dt)
652653
else:
653654
if order == "A":
654-
order = "F" if a_min.flags.f_contiguous else "C"
655+
order = (
656+
"F"
657+
if all(
658+
arr.flags.f_contiguous
659+
for arr in (
660+
x,
661+
a_min,
662+
)
663+
)
664+
else "C"
665+
)
655666
buf2 = dpt.empty_like(a_max, dtype=buf2_dt, order=order)
656667
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
657668
src=a_max, dst=buf2, sycl_queue=exec_q
@@ -701,12 +712,23 @@ def clip(x, min=None, max=None, out=None, order="K"):
701712
ht_copy_ev.wait()
702713
ht_binary_ev.wait()
703714
return out
715+
704716
elif buf2_dt is None:
705717
if order == "K":
706718
buf1 = _empty_like_orderK(a_min, buf1_dt)
707719
else:
708720
if order == "A":
709-
order = "F" if a_min.flags.f_contiguous else "C"
721+
order = (
722+
"F"
723+
if all(
724+
arr.flags.f_contiguous
725+
for arr in (
726+
x,
727+
a_max,
728+
)
729+
)
730+
else "C"
731+
)
710732
buf1 = dpt.empty_like(a_min, dtype=buf1_dt, order=order)
711733
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
712734
src=a_min, dst=buf1, sycl_queue=exec_q
@@ -758,9 +780,17 @@ def clip(x, min=None, max=None, out=None, order="K"):
758780
return out
759781

760782
if order in ["K", "A"]:
761-
if a_min.flags.f_contiguous and a_max.flags.f_contiguous:
783+
if (
784+
x.flags.f_contiguous
785+
and a_min.flags.f_contiguous
786+
and a_max.flags.f_contiguous
787+
):
762788
order = "F"
763-
elif a_min.flags.c_contiguous and a_max.flags.c_contiguous:
789+
elif (
790+
x.flags.c_contiguous
791+
and a_min.flags.c_contiguous
792+
and a_max.flags.c_contiguous
793+
):
764794
order = "C"
765795
else:
766796
order = "C" if order == "A" else "K"

0 commit comments

Comments
 (0)