From 202c46bb1afaa892a99afe8e96093811bae55f37 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Mon, 9 Dec 2024 17:55:52 +0000 Subject: [PATCH] bug: where: check `condition` is boolean --- array_api_strict/_searching_functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/array_api_strict/_searching_functions.py b/array_api_strict/_searching_functions.py index 0d7c0c8..5460b30 100644 --- a/array_api_strict/_searching_functions.py +++ b/array_api_strict/_searching_functions.py @@ -1,7 +1,7 @@ from __future__ import annotations from ._array_object import Array -from ._dtypes import _result_type, _real_numeric_dtypes +from ._dtypes import _result_type, _real_numeric_dtypes, bool as _bool from ._flags import requires_data_dependent_shapes, requires_api_version from typing import TYPE_CHECKING @@ -80,6 +80,9 @@ def where(condition: Array, x1: Array, x2: Array, /) -> Array: """ # Call result type here just to raise on disallowed type combinations _result_type(x1.dtype, x2.dtype) + + if condition.dtype != _bool: + raise TypeError("`condition` must be have a boolean data type") if len({a.device for a in (condition, x1, x2)}) > 1: raise ValueError("where inputs must all be on the same device")