diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index cc6bd1a..d8ed018 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -677,10 +677,16 @@ def __iter__(self: Array, /): """ Performs the operation __iter__. """ - # Manually disable iteration, since __getitem__ raises IndexError on - # things like ones((3, 3))[0], which causes list(ones((3, 3))) to give - # []. - raise TypeError("array iteration is not allowed in array-api-strict") + # Manually disable iteration on higher dimensional arrays, since + # __getitem__ raises IndexError on things like ones((3, 3))[0], which + # causes list(ones((3, 3))) to give []. + if self.ndim > 1: + raise TypeError("array iteration is not allowed in array-api-strict") + # Allow iteration for 1-D arrays. The array API doesn't strictly + # define __iter__, but it doesn't disallow it. The default Python + # behavior is to implement iter as a[0], a[1], ... when __getitem__ is + # implemented, which implies iteration on 1-D arrays. + return (Array._new(i) for i in self._array) def __le__(self: Array, other: Union[int, float, Array], /) -> Array: """ diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index b28c747..b0d4868 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -1,4 +1,5 @@ import operator +from builtins import all as all_ from numpy.testing import assert_raises, suppress_warnings import numpy as np @@ -21,6 +22,7 @@ int32, int64, uint64, + float64, bool as bool_, ) from .._flags import set_array_api_strict_flags @@ -423,8 +425,12 @@ def test_array_namespace(): pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11")) pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12")) -def test_no_iter(): - pytest.raises(TypeError, lambda: iter(ones(3))) +def test_iter(): + pytest.raises(TypeError, lambda: iter(asarray(3))) + assert list(ones(3)) == [asarray(1.), asarray(1.), asarray(1.)] + assert all_(isinstance(a, Array) for a in iter(ones(3))) + assert all_(a.shape == () for a in iter(ones(3))) + assert all_(a.dtype == float64 for a in iter(ones(3))) pytest.raises(TypeError, lambda: iter(ones((3, 3)))) @pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12']) diff --git a/docs/changelog.md b/docs/changelog.md index 9c5da3c..7b40fe3 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,14 @@ # Changelog +### 2.0.1 (2024-07-01) + +## Minor Changes + +- Re-allow iteration on 1-D arrays. A change from 2.0 fixed iter() raising on + n-D arrays but also made 1-D arrays raise. The standard does not explicitly + disallow iteration on 1-D arrays, and the default Python `__iter__` + implementation allows it to work, so for now, it is kept intact as working. + ## 2.0 (2024-06-27) ### Major Changes