Skip to content

Commit b5bf75c

Browse files
authored
Merge pull request #118 from crusaderky/simple_at
2 parents 6ee70c0 + 6f0ef5c commit b5bf75c

File tree

1 file changed

+51
-60
lines changed
  • src/array_api_extra/_lib

1 file changed

+51
-60
lines changed

src/array_api_extra/_lib/_at.py

Lines changed: 51 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -185,22 +185,42 @@ def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
185185
raise ValueError(msg)
186186
return at(self._x, idx)
187187

188-
def _update_common(
188+
def _op(
189189
self,
190190
at_op: _AtOp,
191-
y: Array,
191+
in_place_op: Callable[[Array, Array | object], Array] | None,
192+
y: Array | object,
192193
/,
193194
copy: bool | None,
194195
xp: ModuleType | None,
195-
) -> tuple[Array, None] | tuple[None, Array]: # numpydoc ignore=PR01
196+
) -> Array:
196197
"""
197-
Perform common prepocessing to all update operations.
198+
Implement all update operations.
199+
200+
Parameters
201+
----------
202+
at_op : _AtOp
203+
Method of JAX's Array.at[].
204+
in_place_op : Callable[[Array, Array | object], Array] | None
205+
In-place operation to apply on mutable backends::
206+
207+
x[idx] = in_place_op(x[idx], y)
208+
209+
If None::
210+
211+
x[idx] = y
212+
213+
y : array or object
214+
Right-hand side of the operation.
215+
copy : bool or None
216+
Whether to copy the input array. See the class docstring for details.
217+
xp : array_namespace or None
218+
The array namespace for the input array.
198219
199220
Returns
200221
-------
201-
tuple
202-
If the operation can be resolved by ``at[]``, ``(return value, None)``
203-
Otherwise, ``(None, preprocessed x)``.
222+
Array
223+
Updated `x`.
204224
"""
205225
x, idx = self._x, self._idx
206226

@@ -231,7 +251,7 @@ def _update_common(
231251
if is_jax_array(x):
232252
# Use JAX's at[]
233253
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value))
234-
return func(y), None
254+
return func(y)
235255
# Emulate at[] behaviour for non-JAX arrays
236256
# with a copy followed by an update
237257
if xp is None:
@@ -249,52 +269,25 @@ def _update_common(
249269
msg = f"Can't update read-only array {x}"
250270
raise ValueError(msg)
251271

252-
return None, x
272+
if in_place_op:
273+
x[self._idx] = in_place_op(x[self._idx], y)
274+
else: # set()
275+
x[self._idx] = y
276+
return x
253277

254278
def set(
255279
self,
256-
y: Array,
280+
y: Array | object,
257281
/,
258282
copy: bool | None = None,
259283
xp: ModuleType | None = None,
260284
) -> Array: # numpydoc ignore=PR01,RT01
261285
"""Apply ``x[idx] = y`` and return the update array."""
262-
res, x = self._update_common(_AtOp.SET, y, copy=copy, xp=xp)
263-
if res is not None:
264-
return res
265-
assert x is not None
266-
x[self._idx] = y
267-
return x
268-
269-
def _iop(
270-
self,
271-
at_op: _AtOp,
272-
elwise_op: Callable[[Array, Array], Array],
273-
y: Array,
274-
/,
275-
copy: bool | None,
276-
xp: ModuleType | None,
277-
) -> Array: # numpydoc ignore=PR01,RT01
278-
"""
279-
``x[idx] += y`` or equivalent in-place operation on a subset of x.
280-
281-
which is the same as saying
282-
x[idx] = x[idx] + y
283-
Note that this is not the same as
284-
operator.iadd(x[idx], y)
285-
Consider for example when x is a numpy array and idx is a fancy index, which
286-
triggers a deep copy on __getitem__.
287-
"""
288-
res, x = self._update_common(at_op, y, copy=copy, xp=xp)
289-
if res is not None:
290-
return res
291-
assert x is not None
292-
x[self._idx] = elwise_op(x[self._idx], y)
293-
return x
286+
return self._op(_AtOp.SET, None, y, copy=copy, xp=xp)
294287

295288
def add(
296289
self,
297-
y: Array,
290+
y: Array | object,
298291
/,
299292
copy: bool | None = None,
300293
xp: ModuleType | None = None,
@@ -304,70 +297,68 @@ def add(
304297
# Note for this and all other methods based on _iop:
305298
# operator.iadd and operator.add subtly differ in behaviour, as
306299
# only iadd will trigger exceptions when y has an incompatible dtype.
307-
return self._iop(_AtOp.ADD, operator.iadd, y, copy=copy, xp=xp)
300+
return self._op(_AtOp.ADD, operator.iadd, y, copy=copy, xp=xp)
308301

309302
def subtract(
310303
self,
311-
y: Array,
304+
y: Array | object,
312305
/,
313306
copy: bool | None = None,
314307
xp: ModuleType | None = None,
315308
) -> Array: # numpydoc ignore=PR01,RT01
316309
"""Apply ``x[idx] -= y`` and return the updated array."""
317-
return self._iop(_AtOp.SUBTRACT, operator.isub, y, copy=copy, xp=xp)
310+
return self._op(_AtOp.SUBTRACT, operator.isub, y, copy=copy, xp=xp)
318311

319312
def multiply(
320313
self,
321-
y: Array,
314+
y: Array | object,
322315
/,
323316
copy: bool | None = None,
324317
xp: ModuleType | None = None,
325318
) -> Array: # numpydoc ignore=PR01,RT01
326319
"""Apply ``x[idx] *= y`` and return the updated array."""
327-
return self._iop(_AtOp.MULTIPLY, operator.imul, y, copy=copy, xp=xp)
320+
return self._op(_AtOp.MULTIPLY, operator.imul, y, copy=copy, xp=xp)
328321

329322
def divide(
330323
self,
331-
y: Array,
324+
y: Array | object,
332325
/,
333326
copy: bool | None = None,
334327
xp: ModuleType | None = None,
335328
) -> Array: # numpydoc ignore=PR01,RT01
336329
"""Apply ``x[idx] /= y`` and return the updated array."""
337-
return self._iop(_AtOp.DIVIDE, operator.itruediv, y, copy=copy, xp=xp)
330+
return self._op(_AtOp.DIVIDE, operator.itruediv, y, copy=copy, xp=xp)
338331

339332
def power(
340333
self,
341-
y: Array,
334+
y: Array | object,
342335
/,
343336
copy: bool | None = None,
344337
xp: ModuleType | None = None,
345338
) -> Array: # numpydoc ignore=PR01,RT01
346339
"""Apply ``x[idx] **= y`` and return the updated array."""
347-
return self._iop(_AtOp.POWER, operator.ipow, y, copy=copy, xp=xp)
340+
return self._op(_AtOp.POWER, operator.ipow, y, copy=copy, xp=xp)
348341

349342
def min(
350343
self,
351-
y: Array,
344+
y: Array | object,
352345
/,
353346
copy: bool | None = None,
354347
xp: ModuleType | None = None,
355348
) -> Array: # numpydoc ignore=PR01,RT01
356349
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
357-
if xp is None:
358-
xp = array_namespace(self._x)
350+
xp = array_namespace(self._x) if xp is None else xp
359351
y = xp.asarray(y)
360-
return self._iop(_AtOp.MIN, xp.minimum, y, copy=copy, xp=xp)
352+
return self._op(_AtOp.MIN, xp.minimum, y, copy=copy, xp=xp)
361353

362354
def max(
363355
self,
364-
y: Array,
356+
y: Array | object,
365357
/,
366358
copy: bool | None = None,
367359
xp: ModuleType | None = None,
368360
) -> Array: # numpydoc ignore=PR01,RT01
369361
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
370-
if xp is None:
371-
xp = array_namespace(self._x)
362+
xp = array_namespace(self._x) if xp is None else xp
372363
y = xp.asarray(y)
373-
return self._iop(_AtOp.MAX, xp.maximum, y, copy=copy, xp=xp)
364+
return self._op(_AtOp.MAX, xp.maximum, y, copy=copy, xp=xp)

0 commit comments

Comments
 (0)