Skip to content

Commit 10996a3

Browse files
authored
MAINT: various refactoring (#101)
1 parent 84b540c commit 10996a3

File tree

8 files changed

+1550
-1374
lines changed

8 files changed

+1550
-1374
lines changed

pixi.lock

Lines changed: 1018 additions & 853 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ cupy = "*"
167167

168168
[tool.pixi.environments]
169169
default = { solve-group = "default" }
170-
lint = { features = ["lint"], solve-group = "default" }
170+
lint = { features = ["lint", "backends"], solve-group = "default" }
171171
tests = { features = ["tests"], solve-group = "default" }
172172
docs = { features = ["docs"], solve-group = "default" }
173173
dev = { features = ["lint", "tests", "docs", "dev", "backends"], solve-group = "default" }

src/array_api_extra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Extra array functions built on top of the array API standard."""
22

33
from ._delegation import pad
4+
from ._lib._at import at
45
from ._lib._funcs import (
5-
at,
66
atleast_nd,
77
cov,
88
create_diagonal,

src/array_api_extra/_lib/_at.py

Lines changed: 373 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,373 @@
1+
"""Update operations for read-only arrays."""
2+
3+
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
4+
from __future__ import annotations
5+
6+
import operator
7+
from collections.abc import Callable
8+
from enum import Enum
9+
from types import ModuleType
10+
from typing import ClassVar, cast
11+
12+
from ._utils._compat import array_namespace, is_jax_array, is_writeable_array
13+
from ._utils._typing import Array, Index
14+
15+
16+
class _AtOp(Enum):
17+
"""Operations for use in `xpx.at`."""
18+
19+
SET = "set"
20+
ADD = "add"
21+
SUBTRACT = "subtract"
22+
MULTIPLY = "multiply"
23+
DIVIDE = "divide"
24+
POWER = "power"
25+
MIN = "min"
26+
MAX = "max"
27+
28+
# @override from Python 3.12
29+
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride]
30+
"""
31+
Return string representation (useful for pytest logs).
32+
33+
Returns
34+
-------
35+
str
36+
The operation's name.
37+
"""
38+
return self.value
39+
40+
41+
_undef = object()
42+
43+
44+
class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
45+
"""
46+
Update operations for read-only arrays.
47+
48+
This implements ``jax.numpy.ndarray.at`` for all writeable
49+
backends (those that support ``__setitem__``) and routes
50+
to the ``.at[]`` method for JAX arrays.
51+
52+
Parameters
53+
----------
54+
x : array
55+
Input array.
56+
idx : index, optional
57+
Only `array API standard compliant indices
58+
<https://data-apis.org/array-api/latest/API_specification/indexing.html>`_
59+
are supported.
60+
61+
You may use two alternate syntaxes::
62+
63+
>>> import array_api_extra as xpx
64+
>>> xpx.at(x, idx).set(value) # or add(value), etc.
65+
>>> xpx.at(x)[idx].set(value)
66+
67+
copy : bool, optional
68+
None (default)
69+
The array parameter *may* be modified in place if it is
70+
possible and beneficial for performance.
71+
You should not reuse it after calling this function.
72+
True
73+
Ensure that the inputs are not modified.
74+
False
75+
Ensure that the update operation writes back to the input.
76+
Raise ``ValueError`` if a copy cannot be avoided.
77+
78+
xp : array_namespace, optional
79+
The standard-compatible namespace for `x`. Default: infer.
80+
81+
Returns
82+
-------
83+
Updated input array.
84+
85+
Warnings
86+
--------
87+
(a) When you omit the ``copy`` parameter, you should never reuse the parameter
88+
array later on; ideally, you should reassign it immediately::
89+
90+
>>> import array_api_extra as xpx
91+
>>> x = xpx.at(x, 0).set(2)
92+
93+
The above best practice pattern ensures that the behaviour won't change depending
94+
on whether ``x`` is writeable or not, as the original ``x`` object is dereferenced
95+
as soon as ``xpx.at`` returns; this way there is no risk of accidentally updating it
96+
twice.
97+
98+
Conversely, the anti-pattern below must be avoided, as it will result in
99+
different behaviour on read-only versus writeable arrays::
100+
101+
>>> x = xp.asarray([0, 0, 0])
102+
>>> y = xpx.at(x, 0).set(2)
103+
>>> z = xpx.at(x, 1).set(3)
104+
105+
In the above example, both calls to ``xpx.at`` update ``x`` in place *if possible*.
106+
This causes the behaviour to diverge depending on whether ``x`` is writeable or not:
107+
108+
- If ``x`` is writeable, then after the snippet above you'll have
109+
``x == y == z == [2, 3, 0]``
110+
- If ``x`` is read-only, then you'll end up with
111+
``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and ``z == [0, 3, 0]``.
112+
113+
The correct pattern to use if you want diverging outputs from the same input is
114+
to enforce copies::
115+
116+
>>> x = xp.asarray([0, 0, 0])
117+
>>> y = xpx.at(x, 0).set(2, copy=True) # Never updates x
118+
>>> z = xpx.at(x, 1).set(3) # May or may not update x in place
119+
>>> del x # avoid accidental reuse of x as we don't know its state anymore
120+
121+
(b) The array API standard does not support integer array indices.
122+
The behaviour of update methods when the index is an array of integers is
123+
undefined and will vary between backends; this is particularly true when the
124+
index contains multiple occurrences of the same index, e.g.::
125+
126+
>>> import numpy as np
127+
>>> import jax.numpy as jnp
128+
>>> import array_api_extra as xpx
129+
>>> xpx.at(np.asarray([123]), np.asarray([0, 0])).add(1)
130+
array([124])
131+
>>> xpx.at(jnp.asarray([123]), jnp.asarray([0, 0])).add(1)
132+
Array([125], dtype=int32)
133+
134+
See Also
135+
--------
136+
jax.numpy.ndarray.at : Equivalent array method in JAX.
137+
138+
Notes
139+
-----
140+
`sparse <https://sparse.pydata.org/>`_, as well as read-only arrays from libraries
141+
not explicitly covered by ``array-api-compat``, are not supported by update
142+
methods.
143+
144+
Examples
145+
--------
146+
Given either of these equivalent expressions::
147+
148+
>>> import array_api_extra as xpx
149+
>>> x = xpx.at(x)[1].add(2)
150+
>>> x = xpx.at(x, 1).add(2)
151+
152+
If x is a JAX array, they are the same as::
153+
154+
>>> x = x.at[1].add(2)
155+
156+
If x is a read-only numpy array, they are the same as::
157+
158+
>>> x = x.copy()
159+
>>> x[1] += 2
160+
161+
For other known backends, they are the same as::
162+
163+
>>> x[1] += 2
164+
"""
165+
166+
_x: Array
167+
_idx: Index
168+
__slots__: ClassVar[tuple[str, ...]] = ("_idx", "_x")
169+
170+
def __init__(
171+
self, x: Array, idx: Index = _undef, /
172+
) -> None: # numpydoc ignore=GL08
173+
self._x = x
174+
self._idx = idx
175+
176+
def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
177+
"""
178+
Allow for the alternate syntax ``at(x)[start:stop:step]``.
179+
180+
It looks prettier than ``at(x, slice(start, stop, step))``
181+
and feels more intuitive coming from the JAX documentation.
182+
"""
183+
if self._idx is not _undef:
184+
msg = "Index has already been set"
185+
raise ValueError(msg)
186+
return at(self._x, idx)
187+
188+
def _update_common(
189+
self,
190+
at_op: _AtOp,
191+
y: Array,
192+
/,
193+
copy: bool | None,
194+
xp: ModuleType | None,
195+
) -> tuple[Array, None] | tuple[None, Array]: # numpydoc ignore=PR01
196+
"""
197+
Perform preprocessing common to all update operations.
198+
199+
Returns
200+
-------
201+
tuple
202+
If the operation can be resolved by ``at[]``, ``(return value, None)``
203+
Otherwise, ``(None, preprocessed x)``.
204+
"""
205+
x, idx = self._x, self._idx
206+
207+
if idx is _undef:
208+
msg = (
209+
"Index has not been set.\n"
210+
"Usage: either\n"
211+
" at(x, idx).set(value)\n"
212+
"or\n"
213+
" at(x)[idx].set(value)\n"
214+
"(same for all other methods)."
215+
)
216+
raise ValueError(msg)
217+
218+
if copy not in (True, False, None):
219+
msg = f"copy must be True, False, or None; got {copy!r}"
220+
raise ValueError(msg)
221+
222+
if copy is None:
223+
writeable = is_writeable_array(x)
224+
copy = not writeable
225+
elif copy:
226+
writeable = None
227+
else:
228+
writeable = is_writeable_array(x)
229+
230+
if copy:
231+
if is_jax_array(x):
232+
# Use JAX's at[]
233+
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value))
234+
return func(y), None
235+
# Emulate at[] behaviour for non-JAX arrays
236+
# with a copy followed by an update
237+
if xp is None:
238+
xp = array_namespace(x)
239+
x = xp.asarray(x, copy=True)
240+
if writeable is False:
241+
# A copy of a read-only numpy array is writeable
242+
# Note: this assumes that a copy of a writeable array is writeable
243+
writeable = None
244+
245+
if writeable is None:
246+
writeable = is_writeable_array(x)
247+
if not writeable:
248+
# sparse crashes here
249+
msg = f"Can't update read-only array {x}"
250+
raise ValueError(msg)
251+
252+
return None, x
253+
254+
def set(
255+
self,
256+
y: Array,
257+
/,
258+
copy: bool | None = None,
259+
xp: ModuleType | None = None,
260+
) -> Array: # numpydoc ignore=PR01,RT01
261+
"""Apply ``x[idx] = y`` and return the updated array."""
262+
res, x = self._update_common(_AtOp.SET, y, copy=copy, xp=xp)
263+
if res is not None:
264+
return res
265+
assert x is not None
266+
x[self._idx] = y
267+
return x
268+
269+
def _iop(
270+
self,
271+
at_op: _AtOp,
272+
elwise_op: Callable[[Array, Array], Array],
273+
y: Array,
274+
/,
275+
copy: bool | None,
276+
xp: ModuleType | None,
277+
) -> Array: # numpydoc ignore=PR01,RT01
278+
"""
279+
``x[idx] += y`` or equivalent in-place operation on a subset of x.
280+
281+
This is the same as saying
282+
x[idx] = x[idx] + y
283+
Note that this is not the same as
284+
operator.iadd(x[idx], y)
285+
Consider for example when x is a numpy array and idx is a fancy index, which
286+
triggers a deep copy on __getitem__.
287+
"""
288+
res, x = self._update_common(at_op, y, copy=copy, xp=xp)
289+
if res is not None:
290+
return res
291+
assert x is not None
292+
x[self._idx] = elwise_op(x[self._idx], y)
293+
return x
294+
295+
def add(
296+
self,
297+
y: Array,
298+
/,
299+
copy: bool | None = None,
300+
xp: ModuleType | None = None,
301+
) -> Array: # numpydoc ignore=PR01,RT01
302+
"""Apply ``x[idx] += y`` and return the updated array."""
303+
304+
# Note for this and all other methods based on _iop:
305+
# operator.iadd and operator.add subtly differ in behaviour, as
306+
# only iadd will trigger exceptions when y has an incompatible dtype.
307+
return self._iop(_AtOp.ADD, operator.iadd, y, copy=copy, xp=xp)
308+
309+
def subtract(
310+
self,
311+
y: Array,
312+
/,
313+
copy: bool | None = None,
314+
xp: ModuleType | None = None,
315+
) -> Array: # numpydoc ignore=PR01,RT01
316+
"""Apply ``x[idx] -= y`` and return the updated array."""
317+
return self._iop(_AtOp.SUBTRACT, operator.isub, y, copy=copy, xp=xp)
318+
319+
def multiply(
320+
self,
321+
y: Array,
322+
/,
323+
copy: bool | None = None,
324+
xp: ModuleType | None = None,
325+
) -> Array: # numpydoc ignore=PR01,RT01
326+
"""Apply ``x[idx] *= y`` and return the updated array."""
327+
return self._iop(_AtOp.MULTIPLY, operator.imul, y, copy=copy, xp=xp)
328+
329+
def divide(
330+
self,
331+
y: Array,
332+
/,
333+
copy: bool | None = None,
334+
xp: ModuleType | None = None,
335+
) -> Array: # numpydoc ignore=PR01,RT01
336+
"""Apply ``x[idx] /= y`` and return the updated array."""
337+
return self._iop(_AtOp.DIVIDE, operator.itruediv, y, copy=copy, xp=xp)
338+
339+
def power(
340+
self,
341+
y: Array,
342+
/,
343+
copy: bool | None = None,
344+
xp: ModuleType | None = None,
345+
) -> Array: # numpydoc ignore=PR01,RT01
346+
"""Apply ``x[idx] **= y`` and return the updated array."""
347+
return self._iop(_AtOp.POWER, operator.ipow, y, copy=copy, xp=xp)
348+
349+
def min(
350+
self,
351+
y: Array,
352+
/,
353+
copy: bool | None = None,
354+
xp: ModuleType | None = None,
355+
) -> Array: # numpydoc ignore=PR01,RT01
356+
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
357+
if xp is None:
358+
xp = array_namespace(self._x)
359+
y = xp.asarray(y)
360+
return self._iop(_AtOp.MIN, xp.minimum, y, copy=copy, xp=xp)
361+
362+
def max(
363+
self,
364+
y: Array,
365+
/,
366+
copy: bool | None = None,
367+
xp: ModuleType | None = None,
368+
) -> Array: # numpydoc ignore=PR01,RT01
369+
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
370+
if xp is None:
371+
xp = array_namespace(self._x)
372+
y = xp.asarray(y)
373+
return self._iop(_AtOp.MAX, xp.maximum, y, copy=copy, xp=xp)

0 commit comments

Comments
 (0)