Skip to content

Get rid of call_origin in dpnp.where #1760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 53 additions & 48 deletions dpnp/dpnp_iface_searching.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,13 @@


import dpctl.tensor as dpt
import numpy

import dpnp

from .dpnp_array import dpnp_array

# pylint: disable=no-name-in-module
from .dpnp_utils import (
call_origin,
get_usm_allocations,
)

Expand Down Expand Up @@ -298,35 +296,59 @@ def where(condition, x=None, y=None, /):

For full documentation refer to :obj:`numpy.where`.

Parameters
----------
condition : {dpnp.ndarray, usm_ndarray}
Where True, yield `x`, otherwise yield `y`.
x, y : {dpnp.ndarray, usm_ndarray, scalar}, optional
Values from which to choose. `x`, `y` and `condition` need to be
broadcastable to some shape.

Returns
-------
y : dpnp.ndarray
An array with elements from `x` where `condition` is True, and elements
from `y` elsewhere.

Limitations
-----------
Parameter `condition` is supported as either :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Parameters `x` and `y` are supported as either scalar, :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`
Otherwise the function will be executed sequentially on CPU.
Input array data types of `x` and `y` are limited by supported DPNP
:ref:`Data types`.

See Also
--------
:obj:`nonzero` : The function that is called when `x` and `y`are omitted.
:obj:`dpnp.choose` : Construct an array from an index array and a list of
arrays to choose from.
:obj:`dpnp.nonzero` : Return the indices of the elements that are non-zero.

Examples
--------
>>> import dpnp as dp
>>> a = dp.arange(10)
>>> d
>>> import dpnp as np
>>> a = np.arange(10)
>>> a
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> dp.where(a < 5, a, 10*a)
>>> np.where(a < 5, a, 10*a)
array([ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90])

This can be used on multidimensional arrays too:

>>> np.where(np.array([[True, False], [True, True]]),
... np.array([[1, 2], [3, 4]]),
... np.array([[9, 8], [7, 6]]))
array([[1, 8],
[3, 4]])

The shapes of x, y, and the condition are broadcast together:

>>> x, y = np.ogrid[:3, :4]
>>> np.where(x < y, x, 10 + y) # both x and 10+y are broadcast
array([[10, 0, 0, 0],
[10, 11, 1, 1],
[10, 11, 12, 2]])

>>> a = np.array([[0, 1, 2],
... [0, 2, 4],
... [0, 3, 6]])
>>> np.where(a < 4, a, -1) # -1 is broadcast
array([[ 0, 1, 2],
[ 0, 2, -1],
[ 0, 3, -1]])

"""

missing = (x is None, y is None).count(True)
Expand All @@ -336,34 +358,17 @@ def where(condition, x=None, y=None, /):
if missing == 2:
return dpnp.nonzero(condition)

if missing == 0:
if dpnp.is_supported_array_type(condition):
if numpy.isscalar(x) or numpy.isscalar(y):
# get USM type and queue to copy scalar from the host memory
# into a USM allocation
usm_type, queue = get_usm_allocations([condition, x, y])
x = (
dpt.asarray(x, usm_type=usm_type, sycl_queue=queue)
if numpy.isscalar(x)
else x
)
y = (
dpt.asarray(y, usm_type=usm_type, sycl_queue=queue)
if numpy.isscalar(y)
else y
)
if dpnp.is_supported_array_type(x) and dpnp.is_supported_array_type(
y
):
dpt_condition = (
condition.get_array()
if isinstance(condition, dpnp_array)
else condition
)
dpt_x = x.get_array() if isinstance(x, dpnp_array) else x
dpt_y = y.get_array() if isinstance(y, dpnp_array) else y
return dpnp_array._create_from_usm_ndarray(
dpt.where(dpt_condition, dpt_x, dpt_y)
)

return call_origin(numpy.where, condition, x, y)
usm_x = dpnp.get_usm_ndarray_or_scalar(x)
usm_y = dpnp.get_usm_ndarray_or_scalar(y)
usm_condition = dpnp.get_usm_ndarray(condition)

usm_type, queue = get_usm_allocations([condition, x, y])
if dpnp.isscalar(usm_x):
usm_x = dpt.asarray(usm_x, usm_type=usm_type, sycl_queue=queue)

if dpnp.isscalar(usm_y):
usm_y = dpt.asarray(usm_y, usm_type=usm_type, sycl_queue=queue)

return dpnp_array._create_from_usm_ndarray(
dpt.where(usm_condition, usm_x, usm_y)
)
19 changes: 0 additions & 19 deletions tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,22 +906,3 @@ def test_triu_indices_from(array, k):
result = dpnp.triu_indices_from(ia, k)
expected = numpy.triu_indices_from(a, k)
assert_array_equal(expected, result)


@pytest.mark.parametrize("cond_dtype", get_all_dtypes())
@pytest.mark.parametrize("scalar_dtype", get_all_dtypes(no_none=True))
def test_where_with_scalars(cond_dtype, scalar_dtype):
a = numpy.array([-1, 0, 1, 0], dtype=cond_dtype)
ia = dpnp.array(a)

result = dpnp.where(ia, scalar_dtype(1), scalar_dtype(0))
expected = numpy.where(a, scalar_dtype(1), scalar_dtype(0))
assert_array_equal(expected, result)

result = dpnp.where(ia, ia * 2, scalar_dtype(0))
expected = numpy.where(a, a * 2, scalar_dtype(0))
assert_array_equal(expected, result)

result = dpnp.where(ia, scalar_dtype(1), dpnp.array(0))
expected = numpy.where(a, scalar_dtype(1), numpy.array(0))
assert_array_equal(expected, result)
188 changes: 187 additions & 1 deletion tests/test_search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dpctl.tensor as dpt
import numpy
import pytest
from numpy.testing import assert_allclose
from numpy.testing import assert_allclose, assert_array_equal, assert_raises

import dpnp

Expand Down Expand Up @@ -92,3 +92,189 @@ def test_nanargmax_nanargmin_error(func):
# All-NaN slice encountered -> ValueError
with pytest.raises(ValueError):
getattr(dpnp, func)(ia, axis=0)


class TestWhere:
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
def test_basic(self, dtype):
a = numpy.ones(53, dtype=bool)
ia = dpnp.array(a)

np_res = numpy.where(a, dtype(0), dtype(1))
dpnp_res = dpnp.where(ia, dtype(0), dtype(1))
assert_array_equal(np_res, dpnp_res)

np_res = numpy.where(~a, dtype(0), dtype(1))
dpnp_res = dpnp.where(~ia, dtype(0), dtype(1))
assert_array_equal(np_res, dpnp_res)

d = numpy.ones_like(a).astype(dtype)
e = numpy.zeros_like(d)
a[7] = False

ia[7] = False
id = dpnp.array(d)
ie = dpnp.array(e)

np_res = numpy.where(a, e, e)
dpnp_res = dpnp.where(ia, ie, ie)
assert_array_equal(np_res, dpnp_res)

np_res = numpy.where(a, d, e)
dpnp_res = dpnp.where(ia, id, ie)
assert_array_equal(np_res, dpnp_res)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
@pytest.mark.parametrize(
"slice_a, slice_d, slice_e",
[
pytest.param(
slice(None, None, None),
slice(None, None, None),
slice(0, 1, None),
),
pytest.param(
slice(None, None, None),
slice(0, 1, None),
slice(None, None, None),
),
pytest.param(
slice(None, None, 2), slice(None, None, 2), slice(None, None, 2)
),
pytest.param(
slice(1, None, 2), slice(1, None, 2), slice(1, None, 2)
),
pytest.param(
slice(None, None, 3), slice(None, None, 3), slice(None, None, 3)
),
pytest.param(
slice(1, None, 3), slice(1, None, 3), slice(1, None, 3)
),
pytest.param(
slice(None, None, -2),
slice(None, None, -2),
slice(None, None, -2),
),
pytest.param(
slice(None, None, -3),
slice(None, None, -3),
slice(None, None, -3),
),
pytest.param(
slice(1, None, -3), slice(1, None, -3), slice(1, None, -3)
),
],
)
def test_strided(self, dtype, slice_a, slice_d, slice_e):
a = numpy.ones(53, dtype=bool)
a[7] = False
d = numpy.ones_like(a).astype(dtype)
e = numpy.zeros_like(d)

ia = dpnp.array(a)
id = dpnp.array(d)
ie = dpnp.array(e)

np_res = numpy.where(a[slice_a], d[slice_d], e[slice_e])
dpnp_res = dpnp.where(ia[slice_a], id[slice_d], ie[slice_e])
assert_array_equal(np_res, dpnp_res)

def test_zero_sized(self):
a = numpy.array([], dtype=bool).reshape(0, 3)
b = numpy.array([], dtype=numpy.float32).reshape(0, 3)

ia = dpnp.array(a)
ib = dpnp.array(b)

np_res = numpy.where(a, 0, b)
dpnp_res = dpnp.where(ia, 0, ib)
assert_array_equal(np_res, dpnp_res)

def test_ndim(self):
a = numpy.zeros((2, 25))
b = numpy.ones((2, 25))
c = numpy.array([True, False])

ia = dpnp.array(a)
ib = dpnp.array(b)
ic = dpnp.array(c)

np_res = numpy.where(c[:, numpy.newaxis], a, b)
dpnp_res = dpnp.where(ic[:, dpnp.newaxis], ia, ib)
assert_array_equal(np_res, dpnp_res)

np_res = numpy.where(c, a.T, b.T)
dpnp_res = numpy.where(ic, ia.T, ib.T)
assert_array_equal(np_res, dpnp_res)

def test_dtype_mix(self):
a = numpy.uint32(1)
b = numpy.array(
[5.0, 0.0, 3.0, 2.0, -1.0, -4.0, 0.0, -10.0, 10.0, 1.0, 0.0, 3.0],
dtype=numpy.float32,
)
c = numpy.array(
[
False,
True,
False,
False,
False,
False,
True,
False,
False,
False,
True,
False,
]
)

ia = dpnp.array(a)
ib = dpnp.array(b)
ic = dpnp.array(c)

np_res = numpy.where(c, a, b)
dpnp_res = dpnp.where(ic, ia, ib)
assert_array_equal(np_res, dpnp_res)

b = b.astype(numpy.int64)
ib = dpnp.array(b)

np_res = numpy.where(c, a, b)
dpnp_res = dpnp.where(ic, ia, ib)
assert_array_equal(np_res, dpnp_res)

# non bool mask
c = c.astype(int)
c[c != 0] = 34242324
ic = dpnp.array(c)

np_res = numpy.where(c, a, b)
dpnp_res = dpnp.where(ic, ia, ib)
assert_array_equal(np_res, dpnp_res)

# invert
tmpmask = c != 0
c[c == 0] = 41247212
c[tmpmask] = 0
ic = dpnp.array(c)

np_res = numpy.where(c, a, b)
dpnp_res = dpnp.where(ic, ia, ib)
assert_array_equal(np_res, dpnp_res)

def test_error(self):
c = dpnp.array([True, True])
a = dpnp.ones((4, 5))
b = dpnp.ones((5, 5))
assert_raises(ValueError, dpnp.where, c, a, a)
assert_raises(ValueError, dpnp.where, c[0], a, b)

def test_empty_result(self):
a = numpy.zeros((1, 1))
ia = dpnp.array(a)

np_res = numpy.vstack(numpy.where(a == 99.0))
dpnp_res = dpnp.vstack(dpnp.where(ia == 99.0))
assert_array_equal(np_res, dpnp_res)
18 changes: 18 additions & 0 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -1770,6 +1770,24 @@ def test_grid(device, func):
assert_sycl_queue_equal(x.sycl_queue, sycl_queue)


@pytest.mark.parametrize(
"device",
valid_devices,
ids=[device.filter_string for device in valid_devices],
)
def test_where(device):
a = numpy.array([[0, 1, 2], [0, 2, 4], [0, 3, 6]])
ia = dpnp.array(a, device=device)

result = dpnp.where(ia < 4, ia, -1)
expected = numpy.where(a < 4, a, -1)
assert_allclose(expected, result)

expected_queue = ia.get_array().sycl_queue
result_queue = result.get_array().sycl_queue
assert_sycl_queue_equal(result_queue, expected_queue)


@pytest.mark.parametrize(
"device",
valid_devices,
Expand Down
Loading