@@ -185,22 +185,42 @@ def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
185
185
raise ValueError (msg )
186
186
return at (self ._x , idx )
187
187
188
- def _update_common (
188
+ def _op (
189
189
self ,
190
190
at_op : _AtOp ,
191
- y : Array ,
191
+ in_place_op : Callable [[Array , Array | object ], Array ] | None ,
192
+ y : Array | object ,
192
193
/ ,
193
194
copy : bool | None ,
194
195
xp : ModuleType | None ,
195
- ) -> tuple [ Array , None ] | tuple [ None , Array ]: # numpydoc ignore=PR01
196
+ ) -> Array :
196
197
"""
197
- Perform common prepocessing to all update operations.
198
+ Implement all update operations.
199
+
200
+ Parameters
201
+ ----------
202
+ at_op : _AtOp
203
+ Method of JAX's Array.at[].
204
+ in_place_op : Callable[[Array, Array | object], Array] | None
205
+ In-place operation to apply on mutable backends::
206
+
207
+ x[idx] = in_place_op(x[idx], y)
208
+
209
+ If None::
210
+
211
+ x[idx] = y
212
+
213
+ y : array or object
214
+ Right-hand side of the operation.
215
+ copy : bool or None
216
+ Whether to copy the input array. See the class docstring for details.
217
+ xp : array_namespace or None
218
+ The array namespace for the input array.
198
219
199
220
Returns
200
221
-------
201
- tuple
202
- If the operation can be resolved by ``at[]``, ``(return value, None)``
203
- Otherwise, ``(None, preprocessed x)``.
222
+ Array
223
+ Updated `x`.
204
224
"""
205
225
x , idx = self ._x , self ._idx
206
226
@@ -231,7 +251,7 @@ def _update_common(
231
251
if is_jax_array (x ):
232
252
# Use JAX's at[]
233
253
func = cast (Callable [[Array ], Array ], getattr (x .at [idx ], at_op .value ))
234
- return func (y ), None
254
+ return func (y )
235
255
# Emulate at[] behaviour for non-JAX arrays
236
256
# with a copy followed by an update
237
257
if xp is None :
@@ -249,52 +269,25 @@ def _update_common(
249
269
msg = f"Can't update read-only array { x } "
250
270
raise ValueError (msg )
251
271
252
- return None , x
272
+ if in_place_op :
273
+ x [self ._idx ] = in_place_op (x [self ._idx ], y )
274
+ else : # set()
275
+ x [self ._idx ] = y
276
+ return x
253
277
254
278
def set (
255
279
self ,
256
- y : Array ,
280
+ y : Array | object ,
257
281
/ ,
258
282
copy : bool | None = None ,
259
283
xp : ModuleType | None = None ,
260
284
) -> Array : # numpydoc ignore=PR01,RT01
261
285
"""Apply ``x[idx] = y`` and return the update array."""
262
- res , x = self ._update_common (_AtOp .SET , y , copy = copy , xp = xp )
263
- if res is not None :
264
- return res
265
- assert x is not None
266
- x [self ._idx ] = y
267
- return x
268
-
269
- def _iop (
270
- self ,
271
- at_op : _AtOp ,
272
- elwise_op : Callable [[Array , Array ], Array ],
273
- y : Array ,
274
- / ,
275
- copy : bool | None ,
276
- xp : ModuleType | None ,
277
- ) -> Array : # numpydoc ignore=PR01,RT01
278
- """
279
- ``x[idx] += y`` or equivalent in-place operation on a subset of x.
280
-
281
- which is the same as saying
282
- x[idx] = x[idx] + y
283
- Note that this is not the same as
284
- operator.iadd(x[idx], y)
285
- Consider for example when x is a numpy array and idx is a fancy index, which
286
- triggers a deep copy on __getitem__.
287
- """
288
- res , x = self ._update_common (at_op , y , copy = copy , xp = xp )
289
- if res is not None :
290
- return res
291
- assert x is not None
292
- x [self ._idx ] = elwise_op (x [self ._idx ], y )
293
- return x
286
+ return self ._op (_AtOp .SET , None , y , copy = copy , xp = xp )
294
287
295
288
def add (
296
289
self ,
297
- y : Array ,
290
+ y : Array | object ,
298
291
/ ,
299
292
copy : bool | None = None ,
300
293
xp : ModuleType | None = None ,
@@ -304,70 +297,68 @@ def add(
304
297
# Note for this and all other methods based on _iop:
305
298
# operator.iadd and operator.add subtly differ in behaviour, as
306
299
# only iadd will trigger exceptions when y has an incompatible dtype.
307
- return self ._iop (_AtOp .ADD , operator .iadd , y , copy = copy , xp = xp )
300
+ return self ._op (_AtOp .ADD , operator .iadd , y , copy = copy , xp = xp )
308
301
309
302
def subtract (
310
303
self ,
311
- y : Array ,
304
+ y : Array | object ,
312
305
/ ,
313
306
copy : bool | None = None ,
314
307
xp : ModuleType | None = None ,
315
308
) -> Array : # numpydoc ignore=PR01,RT01
316
309
"""Apply ``x[idx] -= y`` and return the updated array."""
317
- return self ._iop (_AtOp .SUBTRACT , operator .isub , y , copy = copy , xp = xp )
310
+ return self ._op (_AtOp .SUBTRACT , operator .isub , y , copy = copy , xp = xp )
318
311
319
312
def multiply (
320
313
self ,
321
- y : Array ,
314
+ y : Array | object ,
322
315
/ ,
323
316
copy : bool | None = None ,
324
317
xp : ModuleType | None = None ,
325
318
) -> Array : # numpydoc ignore=PR01,RT01
326
319
"""Apply ``x[idx] *= y`` and return the updated array."""
327
- return self ._iop (_AtOp .MULTIPLY , operator .imul , y , copy = copy , xp = xp )
320
+ return self ._op (_AtOp .MULTIPLY , operator .imul , y , copy = copy , xp = xp )
328
321
329
322
def divide (
330
323
self ,
331
- y : Array ,
324
+ y : Array | object ,
332
325
/ ,
333
326
copy : bool | None = None ,
334
327
xp : ModuleType | None = None ,
335
328
) -> Array : # numpydoc ignore=PR01,RT01
336
329
"""Apply ``x[idx] /= y`` and return the updated array."""
337
- return self ._iop (_AtOp .DIVIDE , operator .itruediv , y , copy = copy , xp = xp )
330
+ return self ._op (_AtOp .DIVIDE , operator .itruediv , y , copy = copy , xp = xp )
338
331
339
332
def power (
340
333
self ,
341
- y : Array ,
334
+ y : Array | object ,
342
335
/ ,
343
336
copy : bool | None = None ,
344
337
xp : ModuleType | None = None ,
345
338
) -> Array : # numpydoc ignore=PR01,RT01
346
339
"""Apply ``x[idx] **= y`` and return the updated array."""
347
- return self ._iop (_AtOp .POWER , operator .ipow , y , copy = copy , xp = xp )
340
+ return self ._op (_AtOp .POWER , operator .ipow , y , copy = copy , xp = xp )
348
341
349
342
def min (
350
343
self ,
351
- y : Array ,
344
+ y : Array | object ,
352
345
/ ,
353
346
copy : bool | None = None ,
354
347
xp : ModuleType | None = None ,
355
348
) -> Array : # numpydoc ignore=PR01,RT01
356
349
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
357
- if xp is None :
358
- xp = array_namespace (self ._x )
350
+ xp = array_namespace (self ._x ) if xp is None else xp
359
351
y = xp .asarray (y )
360
- return self ._iop (_AtOp .MIN , xp .minimum , y , copy = copy , xp = xp )
352
+ return self ._op (_AtOp .MIN , xp .minimum , y , copy = copy , xp = xp )
361
353
362
354
def max (
363
355
self ,
364
- y : Array ,
356
+ y : Array | object ,
365
357
/ ,
366
358
copy : bool | None = None ,
367
359
xp : ModuleType | None = None ,
368
360
) -> Array : # numpydoc ignore=PR01,RT01
369
361
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
370
- if xp is None :
371
- xp = array_namespace (self ._x )
362
+ xp = array_namespace (self ._x ) if xp is None else xp
372
363
y = xp .asarray (y )
373
- return self ._iop (_AtOp .MAX , xp .maximum , y , copy = copy , xp = xp )
364
+ return self ._op (_AtOp .MAX , xp .maximum , y , copy = copy , xp = xp )
0 commit comments