Skip to content

Commit fd44f1f

Browse files
committed
bug: where: check condition is boolean
1 parent 444830f commit fd44f1f

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

array_api_strict/_searching_functions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array:
8080
"""
8181
# Call result type here just to raise on disallowed type combinations
8282
_result_type(x1.dtype, x2.dtype)
83+
84+
if condition.dtype != bool:
85+
raise TypeError("`condition` must be have a boolean data type")
8386

8487
if len({a.device for a in (condition, x1, x2)}) > 1:
8588
raise ValueError("where inputs must all be on the same device")

0 commit comments

Comments
 (0)