@@ -683,7 +683,7 @@ def _common(
683
683
xp : ModuleType | None = None ,
684
684
_is_update : bool = True ,
685
685
** kwargs : Untyped ,
686
- ) -> tuple [Untyped , None ] | tuple [None , Array ]:
686
+ ) -> tuple [Array , None ] | tuple [None , Array ]:
687
687
"""Perform common prepocessing.
688
688
689
689
Returns
@@ -704,16 +704,22 @@ def _common(
704
704
705
705
x = self .x
706
706
707
+ if copy not in (True , False , None ):
708
+ msg = f"copy must be True, False, or None; got { copy !r} " # pyright: ignore[reportUnreachable]
709
+ raise ValueError (msg )
710
+
707
711
if copy is None :
708
712
writeable = is_writeable_array (x )
709
713
copy = _is_update and not writeable
710
714
elif copy :
711
715
writeable = None
712
- else :
716
+ elif _is_update :
713
717
writeable = is_writeable_array (x )
714
718
if not writeable :
715
719
msg = "Cannot modify parameter in place"
716
720
raise ValueError (msg )
721
+ else :
722
+ writeable = None
717
723
718
724
if copy :
719
725
try :
@@ -723,10 +729,10 @@ def _common(
723
729
# with a copy followed by an update
724
730
if xp is None :
725
731
xp = array_namespace (x )
726
- # Create writeable copy of read-only numpy array
727
732
x = xp .asarray (x , copy = True )
728
733
if writeable is False :
729
734
# A copy of a read-only numpy array is writeable
735
+ # Note: this assumes that a copy of a writeable array is writeable
730
736
writeable = None
731
737
else :
732
738
# Use JAX's at[] or other library that with the same duck-type API
@@ -743,12 +749,18 @@ def _common(
743
749
744
750
return None , x
745
751
746
- def get (self , ** kwargs : Untyped ) -> Untyped :
752
+ def get (
753
+ self ,
754
+ / ,
755
+ copy : bool | None = True ,
756
+ xp : ModuleType | None = None ,
757
+ ** kwargs : Untyped ,
758
+ ) -> Untyped :
747
759
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring
748
760
that the output is either a copy or a view; it also allows passing
749
761
keyword arguments to the backend.
750
762
"""
751
- if kwargs . get ( " copy" ) is False :
763
+ if copy is False :
752
764
if is_array_api_obj (self .idx ):
753
765
# Boolean index. Note that the array API spec
754
766
# https://data-apis.org/array-api/latest/API_specification/indexing.html
@@ -758,19 +770,38 @@ def get(self, **kwargs: Untyped) -> Untyped:
758
770
# which can be caught by testing the user code vs. array-api-strict.
759
771
msg = "get() with an array index always returns a copy"
760
772
raise ValueError (msg )
773
+
774
+ # Prevent scalar indices together with copy=False.
775
+ # Even if some backends may return a scalar view of the original, we chose to be
776
+ # strict here beceause some other backends, such as numpy, definitely don't.
777
+ tup_idx = self .idx if isinstance (self .idx , tuple ) else (self .idx ,)
778
+ if any (
779
+ i is not None and i is not Ellipsis and not isinstance (i , slice )
780
+ for i in tup_idx
781
+ ):
782
+ msg = "get() with a scalar index typically returns a copy"
783
+ raise ValueError (msg )
784
+
761
785
if is_dask_array (self .x ):
762
786
msg = "get() on Dask arrays always returns a copy"
763
787
raise ValueError (msg )
764
788
765
- res , x = self ._common ("get" , _is_update = False , ** kwargs )
789
+ res , x = self ._common ("get" , copy = copy , xp = xp , _is_update = False , ** kwargs )
766
790
if res is not None :
767
791
return res
768
792
assert x is not None
769
793
return x [self .idx ]
770
794
771
- def set (self , y : Array , / , ** kwargs : Untyped ) -> Array :
795
+ def set (
796
+ self ,
797
+ y : Array ,
798
+ / ,
799
+ copy : bool | None = True ,
800
+ xp : ModuleType | None = None ,
801
+ ** kwargs : Untyped ,
802
+ ) -> Array :
772
803
"""Apply ``x[idx] = y`` and return the update array"""
773
- res , x = self ._common ("set" , y , ** kwargs )
804
+ res , x = self ._common ("set" , y , copy = copy , xp = xp , ** kwargs )
774
805
if res is not None :
775
806
return res
776
807
assert x is not None
@@ -785,6 +816,8 @@ def _iop(
785
816
elwise_op : Callable [[Array , Array ], Array ],
786
817
y : Array ,
787
818
/ ,
819
+ copy : bool | None = True ,
820
+ xp : ModuleType | None = None ,
788
821
** kwargs : Untyped ,
789
822
) -> Array :
790
823
"""x[idx] += y or equivalent in-place operation on a subset of x
@@ -796,41 +829,92 @@ def _iop(
796
829
Consider for example when x is a numpy array and idx is a fancy index, which
797
830
triggers a deep copy on __getitem__.
798
831
"""
799
- res , x = self ._common (at_op , y , ** kwargs )
832
+ res , x = self ._common (at_op , y , copy = copy , xp = xp , ** kwargs )
800
833
if res is not None :
801
834
return res
802
835
assert x is not None
803
836
x [self .idx ] = elwise_op (x [self .idx ], y )
804
837
return x
805
838
806
- def add (self , y : Array , / , ** kwargs : Untyped ) -> Array :
839
+ def add (
840
+ self ,
841
+ y : Array ,
842
+ / ,
843
+ copy : bool | None = True ,
844
+ xp : ModuleType | None = None ,
845
+ ** kwargs : Untyped ,
846
+ ) -> Array :
807
847
"""Apply ``x[idx] += y`` and return the updated array"""
808
- return self ._iop ("add" , operator .add , y , ** kwargs )
848
+ return self ._iop ("add" , operator .add , y , copy = copy , xp = xp , ** kwargs )
809
849
810
- def subtract (self , y : Array , / , ** kwargs : Untyped ) -> Array :
850
+ def subtract (
851
+ self ,
852
+ y : Array ,
853
+ / ,
854
+ copy : bool | None = True ,
855
+ xp : ModuleType | None = None ,
856
+ ** kwargs : Untyped ,
857
+ ) -> Array :
811
858
"""Apply ``x[idx] -= y`` and return the updated array"""
812
- return self ._iop ("subtract" , operator .sub , y , ** kwargs )
859
+ return self ._iop ("subtract" , operator .sub , y , copy = copy , xp = xp , ** kwargs )
813
860
814
- def multiply (self , y : Array , / , ** kwargs : Untyped ) -> Array :
861
+ def multiply (
862
+ self ,
863
+ y : Array ,
864
+ / ,
865
+ copy : bool | None = True ,
866
+ xp : ModuleType | None = None ,
867
+ ** kwargs : Untyped ,
868
+ ) -> Array :
815
869
"""Apply ``x[idx] *= y`` and return the updated array"""
816
- return self ._iop ("multiply" , operator .mul , y , ** kwargs )
870
+ return self ._iop ("multiply" , operator .mul , y , copy = copy , xp = xp , ** kwargs )
817
871
818
- def divide (self , y : Array , / , ** kwargs : Untyped ) -> Array :
872
+ def divide (
873
+ self ,
874
+ y : Array ,
875
+ / ,
876
+ copy : bool | None = True ,
877
+ xp : ModuleType | None = None ,
878
+ ** kwargs : Untyped ,
879
+ ) -> Array :
819
880
"""Apply ``x[idx] /= y`` and return the updated array"""
820
- return self ._iop ("divide" , operator .truediv , y , ** kwargs )
881
+ return self ._iop ("divide" , operator .truediv , y , copy = copy , xp = xp , ** kwargs )
821
882
822
- def power (self , y : Array , / , ** kwargs : Untyped ) -> Array :
883
+ def power (
884
+ self ,
885
+ y : Array ,
886
+ / ,
887
+ copy : bool | None = True ,
888
+ xp : ModuleType | None = None ,
889
+ ** kwargs : Untyped ,
890
+ ) -> Array :
823
891
"""Apply ``x[idx] **= y`` and return the updated array"""
824
- return self ._iop ("power" , operator .pow , y , ** kwargs )
892
+ return self ._iop ("power" , operator .pow , y , copy = copy , xp = xp , ** kwargs )
825
893
826
- def min (self , y : Array , / , ** kwargs : Untyped ) -> Array :
894
+ def min (
895
+ self ,
896
+ y : Array ,
897
+ / ,
898
+ copy : bool | None = True ,
899
+ xp : ModuleType | None = None ,
900
+ ** kwargs : Untyped ,
901
+ ) -> Array :
827
902
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array"""
828
- xp = array_namespace (self .x )
903
+ if xp is None :
904
+ xp = array_namespace (self .x )
829
905
y = xp .asarray (y )
830
- return self ._iop ("min" , xp .minimum , y , ** kwargs )
906
+ return self ._iop ("min" , xp .minimum , y , copy = copy , xp = xp , ** kwargs )
831
907
832
- def max (self , y : Array , / , ** kwargs : Untyped ) -> Array :
908
+ def max (
909
+ self ,
910
+ y : Array ,
911
+ / ,
912
+ copy : bool | None = True ,
913
+ xp : ModuleType | None = None ,
914
+ ** kwargs : Untyped ,
915
+ ) -> Array :
833
916
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array"""
834
- xp = array_namespace (self .x )
917
+ if xp is None :
918
+ xp = array_namespace (self .x )
835
919
y = xp .asarray (y )
836
- return self ._iop ("max" , xp .maximum , y , ** kwargs )
920
+ return self ._iop ("max" , xp .maximum , y , copy = copy , xp = xp , ** kwargs )
0 commit comments