Skip to content

Commit 8f82fe1

Browse files
authored
clip permits None for min and max (#1505)
* Fixes `dpt.copy` returning TypeError instead of raising When provided a non-usm_ndarray-input to copy, copy would return the error instead of raising it * Permits clip arguments `min` and `max` to both be `None` Also resolves gh-1489 * Specify that Python scalars are permitted for `max` and `min` in `clip` * Adds tests to `test_tensor_clip.py` improve `_clip.py` coverage
1 parent 8ed8ef2 commit 8f82fe1

File tree

3 files changed

+187
-21
lines changed

3 files changed

+187
-21
lines changed

dpctl/tensor/_clip.py

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,9 @@ def _resolve_one_strong_one_weak_types(st_dtype, dtype, dev):
168168
return dpt.dtype(ti.default_device_int_type(dev))
169169
if isinstance(dtype, WeakComplexType):
170170
if st_dtype is dpt.float16 or st_dtype is dpt.float32:
171-
return st_dtype, dpt.complex64
171+
return dpt.complex64
172172
return _to_device_supported_dtype(dpt.complex128, dev)
173-
return (_to_device_supported_dtype(dpt.float64, dev),)
173+
return _to_device_supported_dtype(dpt.float64, dev)
174174
else:
175175
return st_dtype
176176
else:
@@ -197,8 +197,6 @@ def _check_clip_dtypes(res_dtype, arg1_dtype, arg2_dtype, sycl_dev):
197197

198198

199199
def _clip_none(x, val, out, order, _binary_fn):
200-
if order not in ["K", "C", "F", "A"]:
201-
order = "K"
202200
q1, x_usm_type = x.sycl_queue, x.usm_type
203201
q2, val_usm_type = _get_queue_usm_type(val)
204202
if q2 is None:
@@ -391,9 +389,8 @@ def _clip_none(x, val, out, order, _binary_fn):
391389
return out
392390

393391

394-
# need to handle logic for min or max being None
395-
def clip(x, min=None, max=None, out=None, order="K"):
396-
"""clip(x, min, max, out=None, order="K")
392+
def clip(x, /, min=None, max=None, out=None, order="K"):
393+
"""clip(x, min=None, max=None, out=None, order="K")
397394
398395
Clips to the range [`min_i`, `max_i`] for each element `x_i`
399396
in `x`.
@@ -402,14 +399,14 @@ def clip(x, min=None, max=None, out=None, order="K"):
402399
x (usm_ndarray): Array containing elements to clip.
403400
Must be compatible with `min` and `max` according
404401
to broadcasting rules.
405-
min ({None, usm_ndarray}, optional): Array containing minimum values.
402+
min ({None, Union[usm_ndarray, bool, int, float, complex]}, optional):
403+
Array containing minimum values.
406404
Must be compatible with `x` and `max` according
407405
to broadcasting rules.
408-
Only one of `min` and `max` can be `None`.
409-
max ({None, usm_ndarray}, optional): Array containing maximum values.
406+
max ({None, Union[usm_ndarray, bool, int, float, complex]}, optional):
407+
Array containing maximum values.
410408
Must be compatible with `x` and `min` according
411409
to broadcasting rules.
412-
Only one of `min` and `max` can be `None`.
413410
out ({None, usm_ndarray}, optional):
414411
Output array to populate.
415412
Array must have the correct shape and the expected data type.
@@ -428,10 +425,67 @@ def clip(x, min=None, max=None, out=None, order="K"):
428425
"Expected `x` to be of dpctl.tensor.usm_ndarray type, got "
429426
f"{type(x)}"
430427
)
428+
if order not in ["K", "C", "F", "A"]:
429+
order = "K"
431430
if min is None and max is None:
432-
raise ValueError(
433-
"only one of `min` and `max` is permitted to be `None`"
431+
exec_q = x.sycl_queue
432+
orig_out = out
433+
if out is not None:
434+
if not isinstance(out, dpt.usm_ndarray):
435+
raise TypeError(
436+
"output array must be of usm_ndarray type, got "
437+
f"{type(out)}"
438+
)
439+
440+
if out.shape != x.shape:
441+
raise ValueError(
442+
"The shape of input and output arrays are "
443+
f"inconsistent. Expected output shape is {x.shape}, "
444+
f"got {out.shape}"
445+
)
446+
447+
if x.dtype != out.dtype:
448+
raise ValueError(
449+
f"Output array of type {x.dtype} is needed, "
450+
f"got {out.dtype}"
451+
)
452+
453+
if (
454+
dpctl.utils.get_execution_queue((exec_q, out.sycl_queue))
455+
is None
456+
):
457+
raise ExecutionPlacementError(
458+
"Input and output allocation queues are not compatible"
459+
)
460+
461+
if ti._array_overlap(x, out):
462+
if not ti._same_logical_tensors(x, out):
463+
out = dpt.empty_like(out)
464+
else:
465+
return out
466+
else:
467+
if order == "K":
468+
out = _empty_like_orderK(x, x.dtype)
469+
else:
470+
if order == "A":
471+
order = "F" if x.flags.f_contiguous else "C"
472+
out = dpt.empty_like(x, order=order)
473+
474+
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
475+
src=x, dst=out, sycl_queue=exec_q
434476
)
477+
if not (orig_out is None or orig_out is out):
478+
# Copy the out data from temporary buffer to original memory
479+
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
480+
src=out,
481+
dst=orig_out,
482+
sycl_queue=exec_q,
483+
depends=[copy_ev],
484+
)
485+
ht_copy_out_ev.wait()
486+
out = orig_out
487+
ht_copy_ev.wait()
488+
return out
435489
elif max is None:
436490
return _clip_none(x, min, out, order, tei._maximum)
437491
elif min is None:

dpctl/tensor/_copy_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ def copy(usm_ary, order="K"):
531531
)
532532
order = order[0].upper()
533533
if not isinstance(usm_ary, dpt.usm_ndarray):
534-
return TypeError(
534+
raise TypeError(
535535
f"Expected object of type dpt.usm_ndarray, got {type(usm_ary)}"
536536
)
537537
copy_order = "C"

dpctl/tests/test_tensor_clip.py

Lines changed: 119 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@
2121

2222
import dpctl
2323
import dpctl.tensor as dpt
24-
from dpctl.tensor._type_utils import _can_cast
24+
from dpctl.tensor._elementwise_common import _get_dtype
25+
from dpctl.tensor._type_utils import (
26+
_can_cast,
27+
_strong_dtype_num_kind,
28+
_weak_type_num_kind,
29+
)
2530
from dpctl.utils import ExecutionPlacementError
2631

2732
_all_dtypes = [
@@ -194,6 +199,15 @@ def test_clip_out_need_temporary():
194199
dpt.clip(x[:6], 2, 3, out=x[-6:])
195200
assert dpt.all(x[:-6] == 1) and dpt.all(x[-6:] == 2)
196201

202+
x = dpt.arange(12, dtype="i4")
203+
dpt.clip(x[:6], out=x[-6:])
204+
expected = dpt.arange(6, dtype="i4")
205+
assert dpt.all(x[:-6] == expected) and dpt.all(x[-6:] == expected)
206+
207+
x = dpt.ones(10, dtype="i4")
208+
dpt.clip(x, out=x)
209+
assert dpt.all(x == 1)
210+
197211
x = dpt.full(6, 3, dtype="i4")
198212
a_min = dpt.full(10, 2, dtype="i4")
199213
a_max = dpt.asarray(4, dtype="i4")
@@ -227,6 +241,21 @@ def test_clip_arg_validation():
227241
with pytest.raises(TypeError):
228242
dpt.clip(check, x1, x2)
229243

244+
with pytest.raises(ValueError):
245+
dpt.clip(x1, check, x2)
246+
247+
with pytest.raises(ValueError):
248+
dpt.clip(x1, check)
249+
250+
with pytest.raises(TypeError):
251+
dpt.clip(x1, x1, x2, out=check)
252+
253+
with pytest.raises(TypeError):
254+
dpt.clip(x1, x2, out=check)
255+
256+
with pytest.raises(TypeError):
257+
dpt.clip(x1, out=check)
258+
230259

231260
@pytest.mark.parametrize(
232261
"dt1,dt2", [("i4", "i4"), ("i4", "i2"), ("i2", "i4"), ("i1", "i2")]
@@ -599,22 +628,40 @@ def test_clip_max_less_than_min():
599628
assert dpt.all(res == 0)
600629

601630

602-
def test_clip_minmax_weak_types():
631+
@pytest.mark.parametrize("dt", ["?", "i4", "f4", "c8"])
632+
def test_clip_minmax_weak_types(dt):
603633
get_queue_or_skip()
604634

605-
x = dpt.zeros(10, dtype=dpt.bool)
635+
x = dpt.zeros(10, dtype=dt)
606636
min_list = [False, 0, 0.0, 0.0 + 0.0j]
607637
max_list = [True, 1, 1.0, 1.0 + 0.0j]
638+
608639
for min_v, max_v in zip(min_list, max_list):
609-
if isinstance(min_v, bool) and isinstance(max_v, bool):
610-
y = dpt.clip(x, min_v, max_v)
611-
assert isinstance(y, dpt.usm_ndarray)
640+
st_dt = _strong_dtype_num_kind(dpt.dtype(dt))
641+
wk_dt1 = _weak_type_num_kind(_get_dtype(min_v, x.sycl_device))
642+
wk_dt2 = _weak_type_num_kind(_get_dtype(max_v, x.sycl_device))
643+
644+
if st_dt >= wk_dt1 and st_dt >= wk_dt2:
645+
r = dpt.clip(x, min_v, max_v)
646+
assert isinstance(r, dpt.usm_ndarray)
612647
else:
613648
with pytest.raises(ValueError):
614649
dpt.clip(x, min_v, max_v)
615650

651+
if st_dt >= wk_dt1:
652+
r = dpt.clip(x, min_v)
653+
assert isinstance(r, dpt.usm_ndarray)
654+
655+
r = dpt.clip(x, None, min_v)
656+
assert isinstance(r, dpt.usm_ndarray)
657+
else:
658+
with pytest.raises(ValueError):
659+
dpt.clip(x, min_v)
660+
with pytest.raises(ValueError):
661+
dpt.clip(x, None, max_v)
662+
616663

617-
def test_clip_max_weak_types():
664+
def test_clip_max_weak_type_errors():
618665
get_queue_or_skip()
619666

620667
x = dpt.zeros(10, dtype="i4")
@@ -626,6 +673,15 @@ def test_clip_max_weak_types():
626673
with pytest.raises(ValueError):
627674
dpt.clip(x, 2.5, m)
628675

676+
with pytest.raises(ValueError):
677+
dpt.clip(x, 2.5)
678+
679+
with pytest.raises(ValueError):
680+
dpt.clip(dpt.astype(x, "?"), 2)
681+
682+
with pytest.raises(ValueError):
683+
dpt.clip(dpt.astype(x, "f4"), complex(2))
684+
629685

630686
def test_clip_unaligned():
631687
get_queue_or_skip()
@@ -636,3 +692,59 @@ def test_clip_unaligned():
636692

637693
expected = dpt.full(512, 2, dtype="i4")
638694
assert dpt.all(dpt.clip(x[1:], a_min, a_max) == expected)
695+
696+
697+
def test_clip_none_args():
698+
get_queue_or_skip()
699+
700+
x = dpt.arange(10, dtype="i4")
701+
r = dpt.clip(x)
702+
assert dpt.all(x == r)
703+
704+
705+
def test_clip_shape_errors():
706+
get_queue_or_skip()
707+
708+
x = dpt.ones((4, 4), dtype="i4")
709+
a_min = dpt.ones(5, dtype="i4")
710+
a_max = dpt.ones(5, dtype="i4")
711+
712+
with pytest.raises(ValueError):
713+
dpt.clip(x, a_min, a_max)
714+
715+
with pytest.raises(ValueError):
716+
dpt.clip(x, a_min)
717+
718+
with pytest.raises(ValueError):
719+
dpt.clip(x, 0, 1, out=a_min)
720+
721+
with pytest.raises(ValueError):
722+
dpt.clip(x, 0, out=a_min)
723+
724+
with pytest.raises(ValueError):
725+
dpt.clip(x, out=a_min)
726+
727+
728+
def test_clip_compute_follows_data():
729+
q1 = get_queue_or_skip()
730+
q2 = get_queue_or_skip()
731+
732+
x = dpt.ones(10, dtype="i4", sycl_queue=q1)
733+
a_min = dpt.ones(10, dtype="i4", sycl_queue=q2)
734+
a_max = dpt.ones(10, dtype="i4", sycl_queue=q1)
735+
res = dpt.empty_like(x, sycl_queue=q2)
736+
737+
with pytest.raises(ExecutionPlacementError):
738+
dpt.clip(x, a_min, a_max)
739+
740+
with pytest.raises(ExecutionPlacementError):
741+
dpt.clip(x, dpt.ones_like(x), a_max, out=res)
742+
743+
with pytest.raises(ExecutionPlacementError):
744+
dpt.clip(x, a_min)
745+
746+
with pytest.raises(ExecutionPlacementError):
747+
dpt.clip(x, None, a_max, out=res)
748+
749+
with pytest.raises(ExecutionPlacementError):
750+
dpt.clip(x, out=res)

0 commit comments

Comments
 (0)