Skip to content

Commit 0a4813e

Browse files
committed
Allow iteration on 1-D arrays
upstream code (see scipy/scipy#21074). The standard does not define iteration, but it does define __getitem__, and the default Python __iter__ implements iteration when getitem is defined as a[0], a[1], ..., implying that iteration ought to work for 1-D arrays. Iteration is still disallowed for higher dimensional arrays, since getitem would not necessarily work with a single integer index (and this is the case that is controversial). In those cases, the new unstack() function would be preferable. At best it would be good to get upstream clarification from the standard whether iteration should always work or not before disallowing 1-D array iteration.
1 parent 1edb7b0 commit 0a4813e

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

array_api_strict/_array_object.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -677,10 +677,16 @@ def __iter__(self: Array, /):
677677
"""
678678
Performs the operation __iter__.
679679
"""
680-
# Manually disable iteration, since __getitem__ raises IndexError on
681-
# things like ones((3, 3))[0], which causes list(ones((3, 3))) to give
682-
# [].
683-
raise TypeError("array iteration is not allowed in array-api-strict")
680+
# Manually disable iteration on higher dimensional arrays, since
681+
# __getitem__ raises IndexError on things like ones((3, 3))[0], which
682+
# causes list(ones((3, 3))) to give [].
683+
if self.ndim > 1:
684+
raise TypeError("array iteration is not allowed in array-api-strict")
685+
# Allow iteration for 1-D arrays. The array API doesn't strictly
686+
# define __iter__, but it doesn't disallow it. The default Python
687+
# behavior is to implement iter as a[0], a[1], ... when __getitem__ is
688+
# implemented, which implies iteration on 1-D arrays.
689+
return (Array._new(i) for i in self._array)
684690

685691
def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
686692
"""

array_api_strict/tests/test_array_object.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import operator
2+
from builtins import all as all_
23

34
from numpy.testing import assert_raises, suppress_warnings
45
import numpy as np
@@ -21,6 +22,7 @@
2122
int32,
2223
int64,
2324
uint64,
25+
float64,
2426
bool as bool_,
2527
)
2628
from .._flags import set_array_api_strict_flags
@@ -423,8 +425,12 @@ def test_array_namespace():
423425
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
424426
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))
425427

426-
def test_no_iter():
427-
pytest.raises(TypeError, lambda: iter(ones(3)))
428+
def test_iter():
429+
pytest.raises(TypeError, lambda: iter(asarray(3)))
430+
assert list(ones(3)) == [asarray(1.), asarray(1.), asarray(1.)]
431+
assert all_(isinstance(a, Array) for a in iter(ones(3)))
432+
assert all_(a.shape == () for a in iter(ones(3)))
433+
assert all_(a.dtype == float64 for a in iter(ones(3)))
428434
pytest.raises(TypeError, lambda: iter(ones((3, 3))))
429435

430436
@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])

0 commit comments

Comments
 (0)