Skip to content

Allow iteration on 1-D arrays #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,10 +677,16 @@ def __iter__(self: Array, /):
"""
Performs the operation __iter__.
"""
# Manually disable iteration, since __getitem__ raises IndexError on
# things like ones((3, 3))[0], which causes list(ones((3, 3))) to give
# [].
raise TypeError("array iteration is not allowed in array-api-strict")
# Manually disable iteration on higher dimensional arrays, since
# __getitem__ raises IndexError on things like ones((3, 3))[0], which
# causes list(ones((3, 3))) to give [].
if self.ndim > 1:
raise TypeError("array iteration is not allowed in array-api-strict")
# Allow iteration for 1-D arrays. The array API doesn't strictly
# define __iter__, but it doesn't disallow it. The default Python
# behavior is to implement iter as a[0], a[1], ... when __getitem__ is
# implemented, which implies iteration on 1-D arrays.
return (Array._new(i) for i in self._array)

def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Expand Down
10 changes: 8 additions & 2 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
from builtins import all as all_

from numpy.testing import assert_raises, suppress_warnings
import numpy as np
Expand All @@ -21,6 +22,7 @@
int32,
int64,
uint64,
float64,
bool as bool_,
)
from .._flags import set_array_api_strict_flags
Expand Down Expand Up @@ -423,8 +425,12 @@ def test_array_namespace():
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2021.11"))
pytest.raises(ValueError, lambda: a.__array_namespace__(api_version="2024.12"))

def test_no_iter():
pytest.raises(TypeError, lambda: iter(ones(3)))
def test_iter():
pytest.raises(TypeError, lambda: iter(asarray(3)))
assert list(ones(3)) == [asarray(1.), asarray(1.), asarray(1.)]
assert all_(isinstance(a, Array) for a in iter(ones(3)))
assert all_(a.shape == () for a in iter(ones(3)))
assert all_(a.dtype == float64 for a in iter(ones(3)))
pytest.raises(TypeError, lambda: iter(ones((3, 3))))

@pytest.mark.parametrize("api_version", ['2021.12', '2022.12', '2023.12'])
Expand Down
9 changes: 9 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Changelog

### 2.0.1 (2024-07-01)

## Minor Changes

- Re-allow iteration on 1-D arrays. A change from 2.0 fixed iter() raising on
n-D arrays but also made 1-D arrays raise. The standard does not explicitly
disallow iteration on 1-D arrays, and the default Python `__iter__`
implementation allows it to work, so for now, it is kept intact as working.

## 2.0 (2024-06-27)

### Major Changes
Expand Down
Loading