diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 3e120e7e..7f2698c4 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -242,6 +242,81 @@ def test_setitem_masking(shape, data): ) +# ### Fancy indexing ### + +@pytest.mark.min_version("2024.12") +@pytest.mark.unvectorized +@pytest.mark.parametrize("idx_max_dims", [1, None]) +@given(shape=hh.shapes(min_dims=2), data=st.data()) +def test_getitem_arrays_and_ints_1(shape, data, idx_max_dims): + # min_dims=2 : test multidim `x` arrays + # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None + _test_getitem_arrays_and_ints(shape, data, idx_max_dims) + + +@pytest.mark.min_version("2024.12") +@pytest.mark.unvectorized +@pytest.mark.parametrize("idx_max_dims", [1, None]) +@given(shape=hh.shapes(min_dims=1), data=st.data()) +def test_getitem_arrays_and_ints_2(shape, data, idx_max_dims): + # min_dims=1 : favor 1D `x` arrays + # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None + _test_getitem_arrays_and_ints(shape, data, idx_max_dims) + + +def _test_getitem_arrays_and_ints(shape, data, idx_max_dims): + assume((len(shape) > 0) and all(sh > 0 for sh in shape)) + + dtype = xp.int32 + obj = data.draw(scalar_objects(dtype, shape), label="obj") + x = xp.asarray(obj, dtype=dtype) + + # draw a mix of ints and index arrays + arr_index = [data.draw(st.booleans()) for _ in range(len(shape))] + assume(sum(arr_index) > 0) + + # draw shapes for index arrays: max_dims=1 ==> 1D indexing arrays ONLY + # max_dims=None ==> multidim indexing arrays + if sum(arr_index) > 0: + index_shapes = data.draw( + hh.mutually_broadcastable_shapes( + sum(arr_index), min_dims=1, max_dims=idx_max_dims, min_side=1 + ) + ) + index_shapes = list(index_shapes) + + # prepare the indexing tuple, a mix of integer indices and index arrays + key = [] + for i,typ in enumerate(arr_index): + if typ: + # draw an array index + this_idx = data.draw( + xps.arrays( + dtype, + shape=index_shapes.pop(), + elements=st.integers(0, shape[i]-1) + ) + ) + key.append(this_idx) + + else: + # draw an integer + key.append(data.draw(st.integers(-shape[i], shape[i]-1))) + + print(f"??? {x.shape = } {len(key) = } {[xp.asarray(k).shape for k in key]}") + + key = tuple(key) + out = x[key] + + arrays = [xp.asarray(k) for k in key] + bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays]) + bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays] + + for idx in sh.ndindex(bcast_shape): + tpl = tuple(k[idx] for k in bcast_key) + assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }" + + def make_scalar_casting_param( method_name: str, dtype: DataType, stype: ScalarType ) -> Param: