Skip to content

Commit 7f369b0

Browse files
committed
Adds more tests for clip
1 parent 0df7721 commit 7f369b0

File tree

2 files changed

+93
-32
lines changed

2 files changed

+93
-32
lines changed

dpctl/tensor/_clip.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def _clip_none_call(x, val, out, order, _binary_fn):
281281

282282
if res_dt != out.dtype:
283283
raise TypeError(
284-
f"Output array of type {res_dt} is needed," f"got {out.dtype}"
284+
f"Output array of type {res_dt} is needed, got {out.dtype}"
285285
)
286286

287287
if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
@@ -699,7 +699,7 @@ def clip(x, min=None, max=None, out=None, order="K"):
699699
else:
700700
if res_dt != out.dtype:
701701
raise TypeError(
702-
f"Output array of type {res_dt} is needed,"
702+
f"Output array of type {res_dt} is needed, "
703703
f"got {out.dtype}"
704704
)
705705
x = dpt.broadcast_to(x, res_shape)

dpctl/tests/test_tensor_clip.py

Lines changed: 91 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ def test_clip_python_scalars():
129129
for zero, one, arr in zip(py_zeros, py_ones, arrs):
130130
r = dpt.clip(arr, zero, one)
131131
assert isinstance(r, dpt.usm_ndarray)
132+
r = dpt.clip(arr, min=zero)
133+
assert isinstance(r, dpt.usm_ndarray)
132134

133135

134136
def test_clip_in_place():
@@ -169,14 +171,46 @@ def test_clip_out_need_temporary():
169171
get_queue_or_skip()
170172

171173
x = dpt.ones(10, dtype="i4")
174+
a_min = dpt.asarray(2, dtype="i4")
175+
a_max = dpt.asarray(3, dtype="i4")
176+
dpt.clip(x[:6], 2, 3, out=x[-6:])
177+
assert dpt.all(x[:-6] == 1) and dpt.all(x[-6:] == 2)
178+
179+
x = dpt.ones(10, dtype="i4")
180+
a_min = dpt.asarray(2, dtype="i4")
181+
a_max = dpt.asarray(3, dtype="i2")
182+
dpt.clip(x[:6], 2, 3, out=x[-6:])
183+
assert dpt.all(x[:-6] == 1) and dpt.all(x[-6:] == 2)
184+
185+
x = dpt.ones(10, dtype="i4")
186+
a_min = dpt.asarray(2, dtype="i2")
187+
a_max = dpt.asarray(3, dtype="i4")
188+
dpt.clip(x[:6], 2, 3, out=x[-6:])
189+
assert dpt.all(x[:-6] == 1) and dpt.all(x[-6:] == 2)
190+
191+
x = dpt.ones(10, dtype="i4")
192+
a_min = dpt.asarray(2, dtype="i2")
193+
a_max = dpt.asarray(3, dtype="i1")
172194
dpt.clip(x[:6], 2, 3, out=x[-6:])
173195
assert dpt.all(x[:-6] == 1) and dpt.all(x[-6:] == 2)
174196

175197
x = dpt.full(6, 3, dtype="i4")
176198
a_min = dpt.full(10, 2, dtype="i4")
177-
dpt.clip(x, min=a_min[:6], max=4, out=a_min[-6:])
199+
a_max = dpt.asarray(4, dtype="i4")
200+
dpt.clip(x, min=a_min[:6], max=a_max, out=a_min[-6:])
178201
assert dpt.all(a_min[:-6] == 2) and dpt.all(a_min[-6:] == 3)
179202

203+
x = dpt.full(6, 3, dtype="i4")
204+
a_min = dpt.full(10, 2, dtype="i4")
205+
a_max = dpt.asarray(4, dtype="i2")
206+
dpt.clip(x, min=a_min[:6], max=a_max, out=a_min[-6:])
207+
assert dpt.all(a_min[:-6] == 2) and dpt.all(a_min[-6:] == 3)
208+
209+
210+
def test_clip_out_need_temporary_none():
211+
get_queue_or_skip()
212+
213+
x = dpt.full(6, 3, dtype="i4")
180214
# with min/max == None
181215
a_min = dpt.full(10, 2, dtype="i4")
182216
dpt.clip(x, min=a_min[:6], max=None, out=a_min[-6:])
@@ -198,7 +232,10 @@ def test_where_arg_validation():
198232
dpt.where(x1, x2, check)
199233

200234

201-
def test_clip_order():
235+
@pytest.mark.parametrize(
236+
"dt1,dt2", [("i4", "i4"), ("i4", "i2"), ("i2", "i4"), ("i1", "i2")]
237+
)
238+
def test_clip_order(dt1, dt2):
202239
get_queue_or_skip()
203240

204241
test_shape = (
@@ -209,8 +246,8 @@ def test_clip_order():
209246
n = test_shape[-1]
210247

211248
ar1 = dpt.ones(test_shape, dtype="i4", order="C")
212-
ar2 = dpt.ones(test_shape, dtype="i4", order="C")
213-
ar3 = dpt.ones(test_shape, dtype="i4", order="C")
249+
ar2 = dpt.ones(test_shape, dtype=dt1, order="C")
250+
ar3 = dpt.ones(test_shape, dtype=dt2, order="C")
214251
r1 = dpt.clip(ar1, ar2, ar3, order="C")
215252
assert r1.flags.c_contiguous
216253
r2 = dpt.clip(ar1, ar2, ar3, order="F")
@@ -220,6 +257,49 @@ def test_clip_order():
220257
r4 = dpt.clip(ar1, ar2, ar3, order="K")
221258
assert r4.flags.c_contiguous
222259

260+
ar1 = dpt.ones(test_shape, dtype="i4", order="F")
261+
ar2 = dpt.ones(test_shape, dtype=dt1, order="F")
262+
ar3 = dpt.ones(test_shape, dtype=dt2, order="F")
263+
r1 = dpt.clip(ar1, ar2, ar3, order="C")
264+
assert r1.flags.c_contiguous
265+
r2 = dpt.clip(ar1, ar2, ar3, order="F")
266+
assert r2.flags.f_contiguous
267+
r3 = dpt.clip(ar1, ar2, ar3, order="A")
268+
assert r3.flags.f_contiguous
269+
r4 = dpt.clip(ar1, ar2, ar3, order="K")
270+
assert r4.flags.f_contiguous
271+
272+
ar1 = dpt.ones(test_shape2, dtype="i4", order="C")[:20, ::-2]
273+
ar2 = dpt.ones(test_shape2, dtype=dt1, order="C")[:20, ::-2]
274+
ar3 = dpt.ones(test_shape2, dtype=dt2, order="C")[:20, ::-2]
275+
r4 = dpt.clip(ar1, ar2, ar3, order="K")
276+
assert r4.strides == (n, -1)
277+
r5 = dpt.clip(ar1, ar2, ar3, order="C")
278+
assert r5.strides == (n, 1)
279+
280+
ar1 = dpt.ones(test_shape2, dtype="i4", order="C")[:20, ::-2].mT
281+
ar2 = dpt.ones(test_shape2, dtype=dt1, order="C")[:20, ::-2].mT
282+
ar3 = dpt.ones(test_shape2, dtype=dt2, order="C")[:20, ::-2].mT
283+
r4 = dpt.clip(ar1, ar2, ar3, order="K")
284+
assert r4.strides == (-1, n)
285+
r5 = dpt.clip(ar1, ar2, ar3, order="C")
286+
assert r5.strides == (n, 1)
287+
288+
289+
@pytest.mark.parametrize("dt", ["i4", "i2"])
290+
def test_clip_none_order(dt):
291+
get_queue_or_skip()
292+
293+
test_shape = (
294+
20,
295+
20,
296+
)
297+
test_shape2 = tuple(2 * dim for dim in test_shape)
298+
n = test_shape[-1]
299+
300+
ar1 = dpt.ones(test_shape, dtype="i4", order="C")
301+
ar2 = dpt.ones(test_shape, dtype=dt, order="C")
302+
223303
r1 = dpt.clip(ar1, min=None, max=ar2, order="C")
224304
assert r1.flags.c_contiguous
225305
r2 = dpt.clip(ar1, min=None, max=ar2, order="F")
@@ -230,16 +310,7 @@ def test_clip_order():
230310
assert r4.flags.c_contiguous
231311

232312
ar1 = dpt.ones(test_shape, dtype="i4", order="F")
233-
ar2 = dpt.ones(test_shape, dtype="i4", order="F")
234-
ar3 = dpt.ones(test_shape, dtype="i4", order="F")
235-
r1 = dpt.clip(ar1, ar2, ar3, order="C")
236-
assert r1.flags.c_contiguous
237-
r2 = dpt.clip(ar1, ar2, ar3, order="F")
238-
assert r2.flags.f_contiguous
239-
r3 = dpt.clip(ar1, ar2, ar3, order="A")
240-
assert r3.flags.f_contiguous
241-
r4 = dpt.clip(ar1, ar2, ar3, order="K")
242-
assert r4.flags.f_contiguous
313+
ar2 = dpt.ones(test_shape, dtype=dt, order="F")
243314

244315
r1 = dpt.clip(ar1, min=None, max=ar2, order="C")
245316
assert r1.flags.c_contiguous
@@ -251,29 +322,19 @@ def test_clip_order():
251322
assert r4.flags.f_contiguous
252323

253324
ar1 = dpt.ones(test_shape2, dtype="i4", order="C")[:20, ::-2]
254-
ar2 = dpt.ones(test_shape2, dtype="i4", order="C")[:20, ::-2]
255-
ar3 = dpt.ones(test_shape2, dtype="i4", order="C")[:20, ::-2]
256-
r4 = dpt.clip(ar1, ar2, ar3, order="K")
257-
assert r4.strides == (n, -1)
258-
r5 = dpt.clip(ar1, ar2, ar3, order="C")
259-
assert r5.strides == (n, 1)
325+
ar2 = dpt.ones(test_shape2, dtype=dt, order="C")[:20, ::-2]
260326

261-
r4 = dpt.clip(ar1, min=None, max=ar3, order="K")
327+
r4 = dpt.clip(ar1, min=None, max=ar2, order="K")
262328
assert r4.strides == (n, -1)
263-
r5 = dpt.clip(ar1, min=None, max=ar3, order="C")
329+
r5 = dpt.clip(ar1, min=None, max=ar2, order="C")
264330
assert r5.strides == (n, 1)
265331

266332
ar1 = dpt.ones(test_shape2, dtype="i4", order="C")[:20, ::-2].mT
267-
ar2 = dpt.ones(test_shape2, dtype="i4", order="C")[:20, ::-2].mT
268-
ar3 = dpt.ones(test_shape2, dtype="i4", order="C")[:20, ::-2].mT
269-
r4 = dpt.clip(ar1, ar2, ar3, order="K")
270-
assert r4.strides == (-1, n)
271-
r5 = dpt.clip(ar1, ar2, ar3, order="C")
272-
assert r5.strides == (n, 1)
333+
ar2 = dpt.ones(test_shape2, dtype=dt, order="C")[:20, ::-2].mT
273334

274-
r4 = dpt.clip(ar1, min=None, max=ar3, order="K")
335+
r4 = dpt.clip(ar1, min=None, max=ar2, order="K")
275336
assert r4.strides == (-1, n)
276-
r5 = dpt.clip(ar1, min=None, max=ar3, order="C")
337+
r5 = dpt.clip(ar1, min=None, max=ar2, order="C")
277338
assert r5.strides == (n, 1)
278339

279340

0 commit comments

Comments
 (0)