Skip to content

Commit 18096a5

Browse files
committed
Self-review
1 parent ca147d5 commit 18096a5

File tree

2 files changed

+140
-28
lines changed

2 files changed

+140
-28
lines changed

src/array_api_extra/_funcs.py

Lines changed: 109 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@ def _common(
683683
xp: ModuleType | None = None,
684684
_is_update: bool = True,
685685
**kwargs: Untyped,
686-
) -> tuple[Untyped, None] | tuple[None, Array]:
686+
) -> tuple[Array, None] | tuple[None, Array]:
687687
"""Perform common prepocessing.
688688
689689
Returns
@@ -704,16 +704,22 @@ def _common(
704704

705705
x = self.x
706706

707+
if copy not in (True, False, None):
708+
msg = f"copy must be True, False, or None; got {copy!r}" # pyright: ignore[reportUnreachable]
709+
raise ValueError(msg)
710+
707711
if copy is None:
708712
writeable = is_writeable_array(x)
709713
copy = _is_update and not writeable
710714
elif copy:
711715
writeable = None
712-
else:
716+
elif _is_update:
713717
writeable = is_writeable_array(x)
714718
if not writeable:
715719
msg = "Cannot modify parameter in place"
716720
raise ValueError(msg)
721+
else:
722+
writeable = None
717723

718724
if copy:
719725
try:
@@ -723,10 +729,10 @@ def _common(
723729
# with a copy followed by an update
724730
if xp is None:
725731
xp = array_namespace(x)
726-
# Create writeable copy of read-only numpy array
727732
x = xp.asarray(x, copy=True)
728733
if writeable is False:
729734
# A copy of a read-only numpy array is writeable
735+
# Note: this assumes that a copy of a writeable array is writeable
730736
writeable = None
731737
else:
732738
# Use JAX's at[] or other library that with the same duck-type API
@@ -743,12 +749,18 @@ def _common(
743749

744750
return None, x
745751

746-
def get(self, **kwargs: Untyped) -> Untyped:
752+
def get(
753+
self,
754+
/,
755+
copy: bool | None = True,
756+
xp: ModuleType | None = None,
757+
**kwargs: Untyped,
758+
) -> Untyped:
747759
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
748760
that the output is either a copy or a view; it also allows passing
749761
keyword arguments to the backend.
750762
"""
751-
if kwargs.get("copy") is False:
763+
if copy is False:
752764
if is_array_api_obj(self.idx):
753765
# Boolean index. Note that the array API spec
754766
# https://data-apis.org/array-api/latest/API_specification/indexing.html
@@ -758,19 +770,38 @@ def get(self, **kwargs: Untyped) -> Untyped:
758770
# which can be caught by testing the user code vs. array-api-strict.
759771
msg = "get() with an array index always returns a copy"
760772
raise ValueError(msg)
773+
774+
# Prevent scalar indices together with copy=False.
775+
# Even if some backends may return a scalar view of the original, we chose to be
776+
# strict here beceause some other backends, such as numpy, definitely don't.
777+
tup_idx = self.idx if isinstance(self.idx, tuple) else (self.idx,)
778+
if any(
779+
i is not None and i is not Ellipsis and not isinstance(i, slice)
780+
for i in tup_idx
781+
):
782+
msg = "get() with a scalar index typically returns a copy"
783+
raise ValueError(msg)
784+
761785
if is_dask_array(self.x):
762786
msg = "get() on Dask arrays always returns a copy"
763787
raise ValueError(msg)
764788

765-
res, x = self._common("get", _is_update=False, **kwargs)
789+
res, x = self._common("get", copy=copy, xp=xp, _is_update=False, **kwargs)
766790
if res is not None:
767791
return res
768792
assert x is not None
769793
return x[self.idx]
770794

771-
def set(self, y: Array, /, **kwargs: Untyped) -> Array:
795+
def set(
796+
self,
797+
y: Array,
798+
/,
799+
copy: bool | None = True,
800+
xp: ModuleType | None = None,
801+
**kwargs: Untyped,
802+
) -> Array:
772803
"""Apply ``x[idx] = y`` and return the update array"""
773-
res, x = self._common("set", y, **kwargs)
804+
res, x = self._common("set", y, copy=copy, xp=xp, **kwargs)
774805
if res is not None:
775806
return res
776807
assert x is not None
@@ -785,6 +816,8 @@ def _iop(
785816
elwise_op: Callable[[Array, Array], Array],
786817
y: Array,
787818
/,
819+
copy: bool | None = True,
820+
xp: ModuleType | None = None,
788821
**kwargs: Untyped,
789822
) -> Array:
790823
"""x[idx] += y or equivalent in-place operation on a subset of x
@@ -796,41 +829,92 @@ def _iop(
796829
Consider for example when x is a numpy array and idx is a fancy index, which
797830
triggers a deep copy on __getitem__.
798831
"""
799-
res, x = self._common(at_op, y, **kwargs)
832+
res, x = self._common(at_op, y, copy=copy, xp=xp, **kwargs)
800833
if res is not None:
801834
return res
802835
assert x is not None
803836
x[self.idx] = elwise_op(x[self.idx], y)
804837
return x
805838

806-
def add(self, y: Array, /, **kwargs: Untyped) -> Array:
839+
def add(
840+
self,
841+
y: Array,
842+
/,
843+
copy: bool | None = True,
844+
xp: ModuleType | None = None,
845+
**kwargs: Untyped,
846+
) -> Array:
807847
"""Apply ``x[idx] += y`` and return the updated array"""
808-
return self._iop("add", operator.add, y, **kwargs)
848+
return self._iop("add", operator.add, y, copy=copy, xp=xp, **kwargs)
809849

810-
def subtract(self, y: Array, /, **kwargs: Untyped) -> Array:
850+
def subtract(
851+
self,
852+
y: Array,
853+
/,
854+
copy: bool | None = True,
855+
xp: ModuleType | None = None,
856+
**kwargs: Untyped,
857+
) -> Array:
811858
"""Apply ``x[idx] -= y`` and return the updated array"""
812-
return self._iop("subtract", operator.sub, y, **kwargs)
859+
return self._iop("subtract", operator.sub, y, copy=copy, xp=xp, **kwargs)
813860

814-
def multiply(self, y: Array, /, **kwargs: Untyped) -> Array:
861+
def multiply(
862+
self,
863+
y: Array,
864+
/,
865+
copy: bool | None = True,
866+
xp: ModuleType | None = None,
867+
**kwargs: Untyped,
868+
) -> Array:
815869
"""Apply ``x[idx] *= y`` and return the updated array"""
816-
return self._iop("multiply", operator.mul, y, **kwargs)
870+
return self._iop("multiply", operator.mul, y, copy=copy, xp=xp, **kwargs)
817871

818-
def divide(self, y: Array, /, **kwargs: Untyped) -> Array:
872+
def divide(
873+
self,
874+
y: Array,
875+
/,
876+
copy: bool | None = True,
877+
xp: ModuleType | None = None,
878+
**kwargs: Untyped,
879+
) -> Array:
819880
"""Apply ``x[idx] /= y`` and return the updated array"""
820-
return self._iop("divide", operator.truediv, y, **kwargs)
881+
return self._iop("divide", operator.truediv, y, copy=copy, xp=xp, **kwargs)
821882

822-
def power(self, y: Array, /, **kwargs: Untyped) -> Array:
883+
def power(
884+
self,
885+
y: Array,
886+
/,
887+
copy: bool | None = True,
888+
xp: ModuleType | None = None,
889+
**kwargs: Untyped,
890+
) -> Array:
823891
"""Apply ``x[idx] **= y`` and return the updated array"""
824-
return self._iop("power", operator.pow, y, **kwargs)
892+
return self._iop("power", operator.pow, y, copy=copy, xp=xp, **kwargs)
825893

826-
def min(self, y: Array, /, **kwargs: Untyped) -> Array:
894+
def min(
895+
self,
896+
y: Array,
897+
/,
898+
copy: bool | None = True,
899+
xp: ModuleType | None = None,
900+
**kwargs: Untyped,
901+
) -> Array:
827902
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
828-
xp = array_namespace(self.x)
903+
if xp is None:
904+
xp = array_namespace(self.x)
829905
y = xp.asarray(y)
830-
return self._iop("min", xp.minimum, y, **kwargs)
906+
return self._iop("min", xp.minimum, y, copy=copy, xp=xp, **kwargs)
831907

832-
def max(self, y: Array, /, **kwargs: Untyped) -> Array:
908+
def max(
909+
self,
910+
y: Array,
911+
/,
912+
copy: bool | None = True,
913+
xp: ModuleType | None = None,
914+
**kwargs: Untyped,
915+
) -> Array:
833916
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
834-
xp = array_namespace(self.x)
917+
if xp is None:
918+
xp = array_namespace(self.x)
835919
y = xp.asarray(y)
836-
return self._iop("max", xp.maximum, y, **kwargs)
920+
return self._iop("max", xp.maximum, y, copy=copy, xp=xp, **kwargs)

tests/test_at.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from array_api_compat import ( # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs]
1010
array_namespace,
1111
is_dask_array,
12+
is_numpy_array,
1213
is_pydata_sparse_array,
1314
is_writeable_array,
1415
)
@@ -110,6 +111,14 @@ def test_get(array: Array, copy: bool | None):
110111
return
111112
expect_copy = True
112113

114+
# get(copy=False) on a read-only numpy array returns a read-only view
115+
if is_numpy_array(array) and not copy and not array.flags.writeable:
116+
out = at(array, slice(2)).get(copy=copy)
117+
assert_array_equal(out, [10.0, 20.0])
118+
assert out.base is array
119+
assert not out.flags.writeable
120+
return
121+
113122
with assert_copy(array, expect_copy):
114123
y = at(array, slice(2)).get(copy=copy)
115124
assert isinstance(y, type(array))
@@ -119,6 +128,18 @@ def test_get(array: Array, copy: bool | None):
119128
y[:] = 40
120129

121130

131+
def test_get_scalar_nocopy(array: Array):
132+
"""get(copy=False) with a scalar index always raises, because some backends
133+
such as numpy and sparse return a np.generic instead of a scalar view
134+
"""
135+
with pytest.raises(ValueError, match="scalar"):
136+
at(array)[0].get(copy=False)
137+
with pytest.raises(ValueError, match="scalar"):
138+
at(array)[(0, )].get(copy=False)
139+
with pytest.raises(ValueError, match="scalar"):
140+
at(array)[..., 0].get(copy=False)
141+
142+
122143
def test_get_bool_indices(array: Array):
123144
"""get() with a boolean array index always returns a copy"""
124145
# sparse violates the array API as it doesn't support
@@ -146,10 +167,17 @@ def test_get_bool_indices(array: Array):
146167
def test_copy_invalid():
147168
a = np.asarray([1, 2, 3])
148169
with pytest.raises(ValueError, match="copy"):
149-
at(a, 0).set(4, copy="invalid")
170+
at(a, 0).set(4, copy="invalid") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
150171

151172

152173
def test_xp():
153174
a = np.asarray([1, 2, 3])
154-
b = at(a, 0).set(4, xp=np)
155-
assert_array_equal(b, [4, 2, 3])
175+
at(a, 0).get(xp=np)
176+
at(a, 0).set(4, xp=np)
177+
at(a, 0).add(4, xp=np)
178+
at(a, 0).subtract(4, xp=np)
179+
at(a, 0).multiply(4, xp=np)
180+
at(a, 0).divide(4, xp=np)
181+
at(a, 0).power(4, xp=np)
182+
at(a, 0).min(4, xp=np)
183+
at(a, 0).max(4, xp=np)

0 commit comments

Comments
 (0)