Skip to content

Commit c67645a

Browse files
committed
Allow boolean indexing
1 parent 4fa7994 commit c67645a

File tree

2 files changed

+49
-6
lines changed

2 files changed

+49
-6
lines changed

pytensor/xtensor/indexing.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,27 @@ def as_idx_variable(idx):
4444
dims = (dim,)
4545
else:
4646
dims = tuple(dim)
47-
idx = xtensor_from_tensor(as_tensor(idx), dims=dims)
47+
idx = as_xtensor(as_tensor(idx), dims=dims)
4848
else:
4949
# Must be integer indices, we already counted for None and slices
5050
try:
5151
idx = as_xtensor(idx)
5252
except TypeError:
5353
idx = as_tensor(idx)
5454
if idx.type.dtype == "bool":
55-
raise NotImplementedError("Boolean indexing not yet supported")
56-
if idx.type.dtype not in discrete_dtypes:
55+
if idx.type.ndim != 1:
56+
# xarray allaws `x[True]`, but I think it is a bug: https://github.com/pydata/xarray/issues/10379
57+
# Otherwise, it is always restricted to 1d boolean indexing arrays
58+
raise NotImplementedError(
59+
"Only 1d boolean indexing arrays are supported"
60+
)
61+
# Convert to nonzero indices
62+
if isinstance(idx.type, XTensorType):
63+
idx = as_xtensor(idx.values.nonzero()[0], dims=idx.type.dims)
64+
else:
65+
idx = idx.nonzero()[0]
66+
elif idx.type.dtype not in discrete_dtypes:
5767
raise TypeError("Numerical indices must be integers or boolean")
58-
if idx.type.dtype == "bool" and idx.type.ndim == 0:
59-
# This can't be triggered right now, but will once we lift the boolean restriction
60-
raise NotImplementedError("Scalar boolean indices not supported")
6168
return idx
6269

6370

tests/xtensor/test_indexing.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,39 @@ def test_scalar_integer_indexing(dims_order):
295295
expected_res2 = x_test[tuple(idxs)]
296296
xr_assert_allclose(res1, expected_res1)
297297
xr_assert_allclose(res2, expected_res2)
298+
299+
300+
def test_unsupported_boolean_indexing():
301+
x = xtensor(dims=("a", "b"), shape=(3, 5))
302+
303+
mat_idx = xtensor("idx", dtype=bool, shape=(4, 2), dims=("a", "b"))
304+
scalar_idx = mat_idx.isel(a=0, b=1)
305+
306+
for idx in (mat_idx, mat_idx.values, scalar_idx, scalar_idx.values):
307+
with pytest.raises(
308+
NotImplementedError,
309+
match="Only 1d boolean indexing arrays are supported",
310+
):
311+
x[idx]
312+
313+
314+
def test_boolean_indexing():
315+
x = xtensor("x", shape=(8, 7), dims=("a", "b"))
316+
bool_idx = xtensor("bool_idx", dtype=bool, shape=(8,), dims=("a",))
317+
int_idx = xtensor("int_idx", dtype=int, shape=(4, 3), dims=("a", "new_dim"))
318+
319+
out_vectorized = x[bool_idx, int_idx]
320+
out_orthogonal = x[bool_idx, int_idx.rename(a="b")]
321+
fn = xr_function([x, bool_idx, int_idx], [out_vectorized, out_orthogonal])
322+
323+
x_test = xr_arange_like(x)
324+
bool_idx_test = DataArray(np.array([True, False] * 4, dtype=bool), dims=("a",))
325+
int_idx_test = DataArray(
326+
np.random.binomial(n=4, p=0.5, size=(4, 3)),
327+
dims=("a", "new_dim"),
328+
)
329+
res1, res2 = fn(x_test, bool_idx_test, int_idx_test)
330+
expected_res1 = x_test[bool_idx_test, int_idx_test]
331+
expected_res2 = x_test[bool_idx_test, int_idx_test.rename(a="b")]
332+
xr_assert_allclose(res1, expected_res1)
333+
xr_assert_allclose(res2, expected_res2)

0 commit comments

Comments
 (0)