@@ -192,7 +192,7 @@ def _check_clip_dtypes(res_dtype, arg1_dtype, arg2_dtype, sycl_dev):
192
192
return None , None , None
193
193
194
194
195
- def _clip_none_call (x , val , out , order , _binary_fn ):
195
+ def _clip_none (x , val , out , order , _binary_fn ):
196
196
if order not in ["K" , "C" , "F" , "A" ]:
197
197
order = "K"
198
198
q1 , x_usm_type = x .sycl_queue , x .usm_type
@@ -429,9 +429,9 @@ def clip(x, min=None, max=None, out=None, order="K"):
429
429
"only one of `min` and `max` is permitted to be `None`"
430
430
)
431
431
elif max is None :
432
- return _clip_none_call (x , min , out , order , ti ._maximum )
432
+ return _clip_none (x , min , out , order , ti ._maximum )
433
433
elif min is None :
434
- return _clip_none_call (x , max , out , order , ti ._minimum )
434
+ return _clip_none (x , max , out , order , ti ._minimum )
435
435
else :
436
436
q1 , x_usm_type = x .sycl_queue , x .usm_type
437
437
q2 , min_usm_type = _get_queue_usm_type (min )
@@ -646,12 +646,23 @@ def clip(x, min=None, max=None, out=None, order="K"):
646
646
out = orig_out
647
647
ht_binary_ev .wait ()
648
648
return out
649
+
649
650
elif buf1_dt is None :
650
651
if order == "K" :
651
652
buf2 = _empty_like_orderK (a_max , buf2_dt )
652
653
else :
653
654
if order == "A" :
654
- order = "F" if a_min .flags .f_contiguous else "C"
655
+ order = (
656
+ "F"
657
+ if all (
658
+ arr .flags .f_contiguous
659
+ for arr in (
660
+ x ,
661
+ a_min ,
662
+ )
663
+ )
664
+ else "C"
665
+ )
655
666
buf2 = dpt .empty_like (a_max , dtype = buf2_dt , order = order )
656
667
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
657
668
src = a_max , dst = buf2 , sycl_queue = exec_q
@@ -701,12 +712,23 @@ def clip(x, min=None, max=None, out=None, order="K"):
701
712
ht_copy_ev .wait ()
702
713
ht_binary_ev .wait ()
703
714
return out
715
+
704
716
elif buf2_dt is None :
705
717
if order == "K" :
706
718
buf1 = _empty_like_orderK (a_min , buf1_dt )
707
719
else :
708
720
if order == "A" :
709
- order = "F" if a_min .flags .f_contiguous else "C"
721
+ order = (
722
+ "F"
723
+ if all (
724
+ arr .flags .f_contiguous
725
+ for arr in (
726
+ x ,
727
+ a_max ,
728
+ )
729
+ )
730
+ else "C"
731
+ )
710
732
buf1 = dpt .empty_like (a_min , dtype = buf1_dt , order = order )
711
733
ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
712
734
src = a_min , dst = buf1 , sycl_queue = exec_q
@@ -758,9 +780,17 @@ def clip(x, min=None, max=None, out=None, order="K"):
758
780
return out
759
781
760
782
if order in ["K" , "A" ]:
761
- if a_min .flags .f_contiguous and a_max .flags .f_contiguous :
783
+ if (
784
+ x .flags .f_contiguous
785
+ and a_min .flags .f_contiguous
786
+ and a_max .flags .f_contiguous
787
+ ):
762
788
order = "F"
763
- elif a_min .flags .c_contiguous and a_max .flags .c_contiguous :
789
+ elif (
790
+ x .flags .c_contiguous
791
+ and a_min .flags .c_contiguous
792
+ and a_max .flags .c_contiguous
793
+ ):
764
794
order = "C"
765
795
else :
766
796
order = "C" if order == "A" else "K"
0 commit comments