Skip to content

Commit a3d6ae3

Browse files
Get rid of call_origin in dpnp.where (#1760)
* Get redi of call_origin in dpnp.where * Update dpnp/dpnp_iface_searching.py Co-authored-by: vlad-perevezentsev <vladislav.perevezentsev@intel.com> --------- Co-authored-by: vlad-perevezentsev <vladislav.perevezentsev@intel.com>
1 parent 9872abc commit a3d6ae3

File tree

5 files changed

+265
-68
lines changed

5 files changed

+265
-68
lines changed

dpnp/dpnp_iface_searching.py

Lines changed: 53 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,13 @@
3939

4040

4141
import dpctl.tensor as dpt
42-
import numpy
4342

4443
import dpnp
4544

4645
from .dpnp_array import dpnp_array
4746

4847
# pylint: disable=no-name-in-module
4948
from .dpnp_utils import (
50-
call_origin,
5149
get_usm_allocations,
5250
)
5351

@@ -298,35 +296,59 @@ def where(condition, x=None, y=None, /):
298296
299297
For full documentation refer to :obj:`numpy.where`.
300298
299+
Parameters
300+
----------
301+
condition : {dpnp.ndarray, usm_ndarray}
302+
Where True, yield `x`, otherwise yield `y`.
303+
x, y : {dpnp.ndarray, usm_ndarray, scalar}, optional
304+
Values from which to choose. `x`, `y` and `condition` need to be
305+
broadcastable to some shape.
306+
301307
Returns
302308
-------
303309
y : dpnp.ndarray
304310
An array with elements from `x` where `condition` is True, and elements
305311
from `y` elsewhere.
306312
307-
Limitations
308-
-----------
309-
Parameter `condition` is supported as either :class:`dpnp.ndarray`
310-
or :class:`dpctl.tensor.usm_ndarray`.
311-
Parameters `x` and `y` are supported as either scalar, :class:`dpnp.ndarray`
312-
or :class:`dpctl.tensor.usm_ndarray`
313-
Otherwise the function will be executed sequentially on CPU.
314-
Input array data types of `x` and `y` are limited by supported DPNP
315-
:ref:`Data types`.
316-
317313
See Also
318314
--------
319-
:obj:`nonzero` : The function that is called when `x` and `y`are omitted.
315+
:obj:`dpnp.choose` : Construct an array from an index array and a list of
316+
arrays to choose from.
317+
:obj:`dpnp.nonzero` : Return the indices of the elements that are non-zero.
320318
321319
Examples
322320
--------
323-
>>> import dpnp as dp
324-
>>> a = dp.arange(10)
325-
>>> d
321+
>>> import dpnp as np
322+
>>> a = np.arange(10)
323+
>>> a
326324
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
327-
>>> dp.where(a < 5, a, 10*a)
325+
>>> np.where(a < 5, a, 10*a)
328326
array([ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90])
329327
328+
This can be used on multidimensional arrays too:
329+
330+
>>> np.where(np.array([[True, False], [True, True]]),
331+
... np.array([[1, 2], [3, 4]]),
332+
... np.array([[9, 8], [7, 6]]))
333+
array([[1, 8],
334+
[3, 4]])
335+
336+
The shapes of x, y, and the condition are broadcast together:
337+
338+
>>> x, y = np.ogrid[:3, :4]
339+
>>> np.where(x < y, x, 10 + y) # both x and 10+y are broadcast
340+
array([[10, 0, 0, 0],
341+
[10, 11, 1, 1],
342+
[10, 11, 12, 2]])
343+
344+
>>> a = np.array([[0, 1, 2],
345+
... [0, 2, 4],
346+
... [0, 3, 6]])
347+
>>> np.where(a < 4, a, -1) # -1 is broadcast
348+
array([[ 0, 1, 2],
349+
[ 0, 2, -1],
350+
[ 0, 3, -1]])
351+
330352
"""
331353

332354
missing = (x is None, y is None).count(True)
@@ -336,34 +358,17 @@ def where(condition, x=None, y=None, /):
336358
if missing == 2:
337359
return dpnp.nonzero(condition)
338360

339-
if missing == 0:
340-
if dpnp.is_supported_array_type(condition):
341-
if numpy.isscalar(x) or numpy.isscalar(y):
342-
# get USM type and queue to copy scalar from the host memory
343-
# into a USM allocation
344-
usm_type, queue = get_usm_allocations([condition, x, y])
345-
x = (
346-
dpt.asarray(x, usm_type=usm_type, sycl_queue=queue)
347-
if numpy.isscalar(x)
348-
else x
349-
)
350-
y = (
351-
dpt.asarray(y, usm_type=usm_type, sycl_queue=queue)
352-
if numpy.isscalar(y)
353-
else y
354-
)
355-
if dpnp.is_supported_array_type(x) and dpnp.is_supported_array_type(
356-
y
357-
):
358-
dpt_condition = (
359-
condition.get_array()
360-
if isinstance(condition, dpnp_array)
361-
else condition
362-
)
363-
dpt_x = x.get_array() if isinstance(x, dpnp_array) else x
364-
dpt_y = y.get_array() if isinstance(y, dpnp_array) else y
365-
return dpnp_array._create_from_usm_ndarray(
366-
dpt.where(dpt_condition, dpt_x, dpt_y)
367-
)
368-
369-
return call_origin(numpy.where, condition, x, y)
361+
usm_x = dpnp.get_usm_ndarray_or_scalar(x)
362+
usm_y = dpnp.get_usm_ndarray_or_scalar(y)
363+
usm_condition = dpnp.get_usm_ndarray(condition)
364+
365+
usm_type, queue = get_usm_allocations([condition, x, y])
366+
if dpnp.isscalar(usm_x):
367+
usm_x = dpt.asarray(usm_x, usm_type=usm_type, sycl_queue=queue)
368+
369+
if dpnp.isscalar(usm_y):
370+
usm_y = dpt.asarray(usm_y, usm_type=usm_type, sycl_queue=queue)
371+
372+
return dpnp_array._create_from_usm_ndarray(
373+
dpt.where(usm_condition, usm_x, usm_y)
374+
)

tests/test_indexing.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -906,22 +906,3 @@ def test_triu_indices_from(array, k):
906906
result = dpnp.triu_indices_from(ia, k)
907907
expected = numpy.triu_indices_from(a, k)
908908
assert_array_equal(expected, result)
909-
910-
911-
@pytest.mark.parametrize("cond_dtype", get_all_dtypes())
912-
@pytest.mark.parametrize("scalar_dtype", get_all_dtypes(no_none=True))
913-
def test_where_with_scalars(cond_dtype, scalar_dtype):
914-
a = numpy.array([-1, 0, 1, 0], dtype=cond_dtype)
915-
ia = dpnp.array(a)
916-
917-
result = dpnp.where(ia, scalar_dtype(1), scalar_dtype(0))
918-
expected = numpy.where(a, scalar_dtype(1), scalar_dtype(0))
919-
assert_array_equal(expected, result)
920-
921-
result = dpnp.where(ia, ia * 2, scalar_dtype(0))
922-
expected = numpy.where(a, a * 2, scalar_dtype(0))
923-
assert_array_equal(expected, result)
924-
925-
result = dpnp.where(ia, scalar_dtype(1), dpnp.array(0))
926-
expected = numpy.where(a, scalar_dtype(1), numpy.array(0))
927-
assert_array_equal(expected, result)

tests/test_search.py

Lines changed: 187 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import dpctl.tensor as dpt
22
import numpy
33
import pytest
4-
from numpy.testing import assert_allclose
4+
from numpy.testing import assert_allclose, assert_array_equal, assert_raises
55

66
import dpnp
77

@@ -92,3 +92,189 @@ def test_nanargmax_nanargmin_error(func):
9292
# All-NaN slice encountered -> ValueError
9393
with pytest.raises(ValueError):
9494
getattr(dpnp, func)(ia, axis=0)
95+
96+
97+
class TestWhere:
98+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
99+
def test_basic(self, dtype):
100+
a = numpy.ones(53, dtype=bool)
101+
ia = dpnp.array(a)
102+
103+
np_res = numpy.where(a, dtype(0), dtype(1))
104+
dpnp_res = dpnp.where(ia, dtype(0), dtype(1))
105+
assert_array_equal(np_res, dpnp_res)
106+
107+
np_res = numpy.where(~a, dtype(0), dtype(1))
108+
dpnp_res = dpnp.where(~ia, dtype(0), dtype(1))
109+
assert_array_equal(np_res, dpnp_res)
110+
111+
d = numpy.ones_like(a).astype(dtype)
112+
e = numpy.zeros_like(d)
113+
a[7] = False
114+
115+
ia[7] = False
116+
id = dpnp.array(d)
117+
ie = dpnp.array(e)
118+
119+
np_res = numpy.where(a, e, e)
120+
dpnp_res = dpnp.where(ia, ie, ie)
121+
assert_array_equal(np_res, dpnp_res)
122+
123+
np_res = numpy.where(a, d, e)
124+
dpnp_res = dpnp.where(ia, id, ie)
125+
assert_array_equal(np_res, dpnp_res)
126+
127+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
128+
@pytest.mark.parametrize(
129+
"slice_a, slice_d, slice_e",
130+
[
131+
pytest.param(
132+
slice(None, None, None),
133+
slice(None, None, None),
134+
slice(0, 1, None),
135+
),
136+
pytest.param(
137+
slice(None, None, None),
138+
slice(0, 1, None),
139+
slice(None, None, None),
140+
),
141+
pytest.param(
142+
slice(None, None, 2), slice(None, None, 2), slice(None, None, 2)
143+
),
144+
pytest.param(
145+
slice(1, None, 2), slice(1, None, 2), slice(1, None, 2)
146+
),
147+
pytest.param(
148+
slice(None, None, 3), slice(None, None, 3), slice(None, None, 3)
149+
),
150+
pytest.param(
151+
slice(1, None, 3), slice(1, None, 3), slice(1, None, 3)
152+
),
153+
pytest.param(
154+
slice(None, None, -2),
155+
slice(None, None, -2),
156+
slice(None, None, -2),
157+
),
158+
pytest.param(
159+
slice(None, None, -3),
160+
slice(None, None, -3),
161+
slice(None, None, -3),
162+
),
163+
pytest.param(
164+
slice(1, None, -3), slice(1, None, -3), slice(1, None, -3)
165+
),
166+
],
167+
)
168+
def test_strided(self, dtype, slice_a, slice_d, slice_e):
169+
a = numpy.ones(53, dtype=bool)
170+
a[7] = False
171+
d = numpy.ones_like(a).astype(dtype)
172+
e = numpy.zeros_like(d)
173+
174+
ia = dpnp.array(a)
175+
id = dpnp.array(d)
176+
ie = dpnp.array(e)
177+
178+
np_res = numpy.where(a[slice_a], d[slice_d], e[slice_e])
179+
dpnp_res = dpnp.where(ia[slice_a], id[slice_d], ie[slice_e])
180+
assert_array_equal(np_res, dpnp_res)
181+
182+
def test_zero_sized(self):
183+
a = numpy.array([], dtype=bool).reshape(0, 3)
184+
b = numpy.array([], dtype=numpy.float32).reshape(0, 3)
185+
186+
ia = dpnp.array(a)
187+
ib = dpnp.array(b)
188+
189+
np_res = numpy.where(a, 0, b)
190+
dpnp_res = dpnp.where(ia, 0, ib)
191+
assert_array_equal(np_res, dpnp_res)
192+
193+
def test_ndim(self):
194+
a = numpy.zeros((2, 25))
195+
b = numpy.ones((2, 25))
196+
c = numpy.array([True, False])
197+
198+
ia = dpnp.array(a)
199+
ib = dpnp.array(b)
200+
ic = dpnp.array(c)
201+
202+
np_res = numpy.where(c[:, numpy.newaxis], a, b)
203+
dpnp_res = dpnp.where(ic[:, dpnp.newaxis], ia, ib)
204+
assert_array_equal(np_res, dpnp_res)
205+
206+
np_res = numpy.where(c, a.T, b.T)
207+
dpnp_res = numpy.where(ic, ia.T, ib.T)
208+
assert_array_equal(np_res, dpnp_res)
209+
210+
def test_dtype_mix(self):
211+
a = numpy.uint32(1)
212+
b = numpy.array(
213+
[5.0, 0.0, 3.0, 2.0, -1.0, -4.0, 0.0, -10.0, 10.0, 1.0, 0.0, 3.0],
214+
dtype=numpy.float32,
215+
)
216+
c = numpy.array(
217+
[
218+
False,
219+
True,
220+
False,
221+
False,
222+
False,
223+
False,
224+
True,
225+
False,
226+
False,
227+
False,
228+
True,
229+
False,
230+
]
231+
)
232+
233+
ia = dpnp.array(a)
234+
ib = dpnp.array(b)
235+
ic = dpnp.array(c)
236+
237+
np_res = numpy.where(c, a, b)
238+
dpnp_res = dpnp.where(ic, ia, ib)
239+
assert_array_equal(np_res, dpnp_res)
240+
241+
b = b.astype(numpy.int64)
242+
ib = dpnp.array(b)
243+
244+
np_res = numpy.where(c, a, b)
245+
dpnp_res = dpnp.where(ic, ia, ib)
246+
assert_array_equal(np_res, dpnp_res)
247+
248+
# non bool mask
249+
c = c.astype(int)
250+
c[c != 0] = 34242324
251+
ic = dpnp.array(c)
252+
253+
np_res = numpy.where(c, a, b)
254+
dpnp_res = dpnp.where(ic, ia, ib)
255+
assert_array_equal(np_res, dpnp_res)
256+
257+
# invert
258+
tmpmask = c != 0
259+
c[c == 0] = 41247212
260+
c[tmpmask] = 0
261+
ic = dpnp.array(c)
262+
263+
np_res = numpy.where(c, a, b)
264+
dpnp_res = dpnp.where(ic, ia, ib)
265+
assert_array_equal(np_res, dpnp_res)
266+
267+
def test_error(self):
268+
c = dpnp.array([True, True])
269+
a = dpnp.ones((4, 5))
270+
b = dpnp.ones((5, 5))
271+
assert_raises(ValueError, dpnp.where, c, a, a)
272+
assert_raises(ValueError, dpnp.where, c[0], a, b)
273+
274+
def test_empty_result(self):
275+
a = numpy.zeros((1, 1))
276+
ia = dpnp.array(a)
277+
278+
np_res = numpy.vstack(numpy.where(a == 99.0))
279+
dpnp_res = dpnp.vstack(dpnp.where(ia == 99.0))
280+
assert_array_equal(np_res, dpnp_res)

tests/test_sycl_queue.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1770,6 +1770,24 @@ def test_grid(device, func):
17701770
assert_sycl_queue_equal(x.sycl_queue, sycl_queue)
17711771

17721772

1773+
@pytest.mark.parametrize(
1774+
"device",
1775+
valid_devices,
1776+
ids=[device.filter_string for device in valid_devices],
1777+
)
1778+
def test_where(device):
1779+
a = numpy.array([[0, 1, 2], [0, 2, 4], [0, 3, 6]])
1780+
ia = dpnp.array(a, device=device)
1781+
1782+
result = dpnp.where(ia < 4, ia, -1)
1783+
expected = numpy.where(a < 4, a, -1)
1784+
assert_allclose(expected, result)
1785+
1786+
expected_queue = ia.get_array().sycl_queue
1787+
result_queue = result.get_array().sycl_queue
1788+
assert_sycl_queue_equal(result_queue, expected_queue)
1789+
1790+
17731791
@pytest.mark.parametrize(
17741792
"device",
17751793
valid_devices,

0 commit comments

Comments
 (0)