Skip to content

Commit 3d617ab

Browse files
committed
Add support for scalar arguments to xp.where
1 parent fe260b8 commit 3d617ab

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

array_api_strict/_flags.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
"2022.12",
2424
"2023.12",
2525
)
26+
# A placeholder value for the next version
27+
next_supported_version = "2024.12"
2628

2729
API_VERSION = default_version = "2023.12"
2830

@@ -134,7 +136,7 @@ def set_array_api_strict_flags(
134136
global API_VERSION, BOOLEAN_INDEXING, DATA_DEPENDENT_SHAPES, ENABLED_EXTENSIONS
135137

136138
if api_version is not None:
137-
if api_version not in supported_versions:
139+
if api_version not in supported_versions + (next_supported_version,):
138140
raise ValueError(f"Unsupported standard version {api_version!r}")
139141
if api_version == "2021.12":
140142
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12", stacklevel=2)

array_api_strict/_searching_functions.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from ._array_object import Array
44
from ._dtypes import _result_type, _real_numeric_dtypes
5-
from ._flags import requires_data_dependent_shapes, requires_api_version
5+
from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags
66

77
from typing import TYPE_CHECKING
88
if TYPE_CHECKING:
@@ -72,12 +72,19 @@ def searchsorted(
7272
# x1 must be 1-D, but NumPy already requires this.
7373
return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device)
7474

75-
def where(condition: Array, x1: Array, x2: Array, /) -> Array:
75+
def where(condition: Array, x1: bool | int | float | Array, x2: bool | int | float | Array, /) -> Array:
7676
"""
7777
Array API compatible wrapper for :py:func:`np.where <numpy.where>`.
7878
7979
See its docstring for more information.
8080
"""
81+
if get_array_api_strict_flags()['api_version'] > '2023.12':
82+
if isinstance(x1, (bool, float, int)):
83+
x1 = Array._new(np.asarray(x1), device=condition.device)
84+
85+
if isinstance(x2, (bool, float, int)):
86+
x2 = Array._new(np.asarray(x2), device=condition.device)
87+
8188
# Call result type here just to raise on disallowed type combinations
8289
_result_type(x1.dtype, x2.dtype)
8390

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import pytest
2+
3+
import array_api_strict as xp
4+
5+
from array_api_strict import ArrayAPIStrictFlags
6+
from array_api_strict._flags import next_supported_version
7+
8+
9+
def test_where_with_scalars():
10+
x = xp.asarray([1, 2, 3, 1])
11+
12+
# Versions up to and including 2023.12 don't support scalar arguments
13+
with pytest.raises(AttributeError, match="object has no attribute 'dtype'"):
14+
xp.where(x == 1, 42, 44)
15+
16+
# Versions after 2023.12 support scalar arguments
17+
with ArrayAPIStrictFlags(api_version=next_supported_version):
18+
x_where = xp.where(x == 1, 42, 44)
19+
20+
expected = xp.asarray([42, 44, 44, 42])
21+
assert xp.all(x_where == expected)

0 commit comments

Comments
 (0)