diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index afee030..0595594 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -327,7 +327,7 @@ def _normalize_two_args(x1, x2) -> Tuple[Array, Array]: # Note: A large fraction of allowed indices are disallowed here (see the # docstring below) - def _validate_index(self, key): + def _validate_index(self, key, op="getitem"): """ Validate an index according to the array API. @@ -390,11 +390,16 @@ def _validate_index(self, key): "zero-dimensional integer arrays and boolean arrays " "are specified in the Array API." ) + if op == "setitem": + if isinstance(i, Array) and i.dtype in _integer_dtypes: + raise IndexError("Fancy indexing __setitem__ is not supported.") nonexpanding_key = [] single_axes = [] n_ellipsis = 0 key_has_mask = False + key_has_index_array = False + key_has_slices = False for i in _key: if i is not None: nonexpanding_key.append(i) @@ -403,6 +408,8 @@ def _validate_index(self, key): if isinstance(i, Array): if i.dtype in _boolean_dtypes: key_has_mask = True + elif i.dtype in _integer_dtypes: + key_has_index_array = True single_axes.append(i) else: # i must not be an array here, to avoid elementwise equals @@ -410,6 +417,8 @@ def _validate_index(self, key): n_ellipsis += 1 else: single_axes.append(i) + if isinstance(i, slice): + key_has_slices = True n_single_axes = len(single_axes) if n_ellipsis > 1: @@ -427,6 +436,12 @@ def _validate_index(self, key): "specified in the Array API." ) + if (key_has_index_array and (n_ellipsis > 0 or key_has_slices or key_has_mask)): + raise IndexError( + "Integer index arrays are only allowed with integer indices; " + f"got {key}." + ) + if n_ellipsis == 0: indexed_shape = self.shape else: @@ -483,14 +498,9 @@ def _validate_index(self, key): "Array API when the array is the sole index." ) if not get_array_api_strict_flags()['boolean_indexing']: - raise RuntimeError("The boolean_indexing flag has been disabled for array-api-strict") - - elif i.dtype in _integer_dtypes and i.ndim != 0: - raise IndexError( - f"Single-axes index {i} is a non-zero-dimensional " - "integer array, but advanced integer indexing is not " - "specified in the Array API." - ) + raise RuntimeError( + "The boolean_indexing flag has been disabled for array-api-strict" + ) elif isinstance(i, tuple): raise IndexError( f"Single-axes index {i} is a tuple, but nested tuple " @@ -902,7 +912,7 @@ def __setitem__( """ # Note: Only indices required by the spec are allowed. See the # docstring of _validate_index - self._validate_index(key) + self._validate_index(key, op="setitem") if isinstance(key, Array): # Indexing self._array with array_api_strict arrays can be erroneous key = key._array diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index edfa073..ef76c28 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from .. import ones, asarray, result_type, all, equal +from .. import ones, arange, reshape, asarray, result_type, all, equal from .._array_object import Array, CPU_DEVICE, Device from .._dtypes import ( _all_dtypes, @@ -45,35 +45,46 @@ def test_validate_index(): a = ones((3, 4)) # Out of bounds slices are not allowed - assert_raises(IndexError, lambda: a[:4]) - assert_raises(IndexError, lambda: a[:-4]) - assert_raises(IndexError, lambda: a[:3:-1]) - assert_raises(IndexError, lambda: a[:-5:-1]) - assert_raises(IndexError, lambda: a[4:]) - assert_raises(IndexError, lambda: a[-4:]) - assert_raises(IndexError, lambda: a[4::-1]) - assert_raises(IndexError, lambda: a[-4::-1]) - - assert_raises(IndexError, lambda: a[...,:5]) - assert_raises(IndexError, lambda: a[...,:-5]) - assert_raises(IndexError, lambda: a[...,:5:-1]) - assert_raises(IndexError, lambda: a[...,:-6:-1]) - assert_raises(IndexError, lambda: a[...,5:]) - assert_raises(IndexError, lambda: a[...,-5:]) - assert_raises(IndexError, lambda: a[...,5::-1]) - assert_raises(IndexError, lambda: a[...,-5::-1]) + assert_raises(IndexError, lambda: a[:4, 0]) + assert_raises(IndexError, lambda: a[:-4, 0]) + assert_raises(IndexError, lambda: a[:3:-1]) # XXX raises for a wrong reason + assert_raises(IndexError, lambda: a[:-5:-1, 0]) + assert_raises(IndexError, lambda: a[4:, 0]) + assert_raises(IndexError, lambda: a[-4:, 0]) + assert_raises(IndexError, lambda: a[4::-1, 0]) + assert_raises(IndexError, lambda: a[-4::-1, 0]) + + assert_raises(IndexError, lambda: a[..., :5]) + assert_raises(IndexError, lambda: a[..., :-5]) + assert_raises(IndexError, lambda: a[..., :5:-1]) + assert_raises(IndexError, lambda: a[..., :-6:-1]) + assert_raises(IndexError, lambda: a[..., 5:]) + assert_raises(IndexError, lambda: a[..., -5:]) + assert_raises(IndexError, lambda: a[..., 5::-1]) + assert_raises(IndexError, lambda: a[..., -5::-1]) # Boolean indices cannot be part of a larger tuple index - assert_raises(IndexError, lambda: a[a[:,0]==1,0]) - assert_raises(IndexError, lambda: a[a[:,0]==1,...]) - assert_raises(IndexError, lambda: a[..., a[0]==1]) + assert_raises(IndexError, lambda: a[a[:, 0] == 1, 0]) + assert_raises(IndexError, lambda: a[a[:, 0] == 1, ...]) + assert_raises(IndexError, lambda: a[..., a[0] == 1]) assert_raises(IndexError, lambda: a[[True, True, True]]) assert_raises(IndexError, lambda: a[(True, True, True),]) - # Integer array indices are not allowed (except for 0-D) - idx = asarray([[0, 1]]) - assert_raises(IndexError, lambda: a[idx]) - assert_raises(IndexError, lambda: a[idx,]) + # Mixing 1D integer array indices with slices, ellipsis or booleans is not allowed + idx = asarray([0, 1]) + assert_raises(IndexError, lambda: a[..., idx]) + assert_raises(IndexError, lambda: a[:, idx]) + assert_raises(IndexError, lambda: a[asarray([True, True]), idx]) + + # 1D integer array indices must have the same length + idx1 = asarray([0, 1]) + idx2 = asarray([0, 1, 1]) + assert_raises(IndexError, lambda: a[idx1, idx2]) + + # Non-integer array indices are not allowed + assert_raises(IndexError, lambda: a[ones(2), 0]) + + # Array-likes (lists, tuples) are not allowed as indices assert_raises(IndexError, lambda: a[[0, 1]]) assert_raises(IndexError, lambda: a[(0, 1), (0, 1)]) assert_raises(IndexError, lambda: a[[0, 1]]) @@ -87,6 +98,45 @@ def test_validate_index(): assert_raises(IndexError, lambda: a[0,]) assert_raises(IndexError, lambda: a[0]) assert_raises(IndexError, lambda: a[:]) + assert_raises(IndexError, lambda: a[idx]) + + +def test_indexing_arrays(): + # indexing with 1D integer arrays and mixes of integers and 1D integer are allowed + + # 1D array + a = arange(5) + idx = asarray([1, 0, 1, 2, -1]) + a_idx = a[idx] + + a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])]) + assert all(a_idx == a_idx_loop) + + # setitem with arrays is not allowed + with assert_raises(IndexError): + a[idx] = 42 + + # mixed array and integer indexing + a = reshape(arange(3*4), (3, 4)) + idx = asarray([1, 0, 1, 2, -1]) + a_idx = a[idx, 1] + + a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])]) + assert all(a_idx == a_idx_loop) + + # index with two arrays + a_idx = a[idx, idx] + a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])]) + assert all(a_idx == a_idx_loop) + + # setitem with arrays is not allowed + with assert_raises(IndexError): + a[idx, idx] = 42 + + # smoke test indexing with ndim > 1 arrays + idx = idx[..., None] + a[idx, idx] + def test_promoted_scalar_inherits_device(): device1 = Device("device1")