10
10
from . import dtype_helpers as dh
11
11
from . import hypothesis_helpers as hh
12
12
from . import pytest_helpers as ph
13
+ from . import shape_helpers as sh
13
14
from . import xps
14
15
from .typing import DataType , Param , Scalar , ScalarType , Shape
15
16
@@ -87,6 +88,7 @@ def test_setitem(shape, data):
87
88
)
88
89
x = xp .asarray (obj , dtype = dtype )
89
90
note (f"{ x = } " )
91
+ # TODO: test setting non-0d arrays
90
92
key = data .draw (xps .indices (shape = shape , max_dims = 0 ), label = "key" )
91
93
value = data .draw (
92
94
xps .from_dtype (dtype ) | xps .arrays (dtype = dtype , shape = ()), label = "value"
@@ -104,10 +106,100 @@ def test_setitem(shape, data):
104
106
else :
105
107
assert res [key ] == value , msg
106
108
else :
107
- ph .assert_0d_equals ("__setitem__" , "value" , value , f"x[{ key } ]" , res [key ])
109
+ ph .assert_0d_equals (
110
+ "__setitem__" , "value" , value , f"modified x[{ key } ]" , res [key ]
111
+ )
112
+ _key = key if isinstance (key , tuple ) else (key ,)
113
+ assume (all (isinstance (i , int ) for i in _key )) # TODO: normalise slices and ellipsis
114
+ _key = tuple (i if i >= 0 else s + i for i , s in zip (_key , x .shape ))
115
+ unaffected_indices = list (sh .ndindex (res .shape ))
116
+ unaffected_indices .remove (_key )
117
+ for idx in unaffected_indices :
118
+ ph .assert_0d_equals (
119
+ "__setitem__" , f"old x[{ idx } ]" , x [idx ], f"modified x[{ idx } ]" , res [idx ]
120
+ )
121
+
122
+
123
+ # TODO: make mask tests optional
124
+
125
+
126
+ @given (hh .shapes (), st .data ())
127
+ def test_getitem_masking (shape , data ):
128
+ x = data .draw (xps .arrays (xps .scalar_dtypes (), shape = shape ), label = "x" )
129
+ mask_shapes = st .one_of (
130
+ st .sampled_from ([x .shape , ()]),
131
+ st .lists (st .booleans (), min_size = x .ndim , max_size = x .ndim ).map (
132
+ lambda l : tuple (s if b else 0 for s , b in zip (x .shape , l ))
133
+ ),
134
+ hh .shapes (),
135
+ )
136
+ key = data .draw (xps .arrays (dtype = xp .bool , shape = mask_shapes ), label = "key" )
108
137
138
+ if key .ndim > x .ndim or not all (
139
+ ks in (xs , 0 ) for xs , ks in zip (x .shape , key .shape )
140
+ ):
141
+ with pytest .raises (IndexError ):
142
+ x [key ]
143
+ return
144
+
145
+ out = x [key ]
109
146
110
- # TODO: test boolean indexing
147
+ ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
148
+ if key .ndim == 0 :
149
+ out_shape = (1 ,) if key else (0 ,)
150
+ out_shape += x .shape
151
+ else :
152
+ size = int (xp .sum (xp .astype (key , xp .uint8 )))
153
+ out_shape = (size ,) + x .shape [key .ndim :]
154
+ ph .assert_shape ("__getitem__" , out .shape , out_shape )
155
+ if not any (s == 0 for s in key .shape ):
156
+ assume (key .ndim == x .ndim ) # TODO: test key.ndim < x.ndim scenarios
157
+ out_indices = sh .ndindex (out .shape )
158
+ for x_idx in sh .ndindex (x .shape ):
159
+ if key [x_idx ]:
160
+ out_idx = next (out_indices )
161
+ ph .assert_0d_equals (
162
+ "__getitem__" ,
163
+ f"x[{ x_idx } ]" ,
164
+ x [x_idx ],
165
+ f"out[{ out_idx } ]" ,
166
+ out [out_idx ],
167
+ )
168
+
169
+
170
+ @given (hh .shapes (), st .data ())
171
+ def test_setitem_masking (shape , data ):
172
+ x = data .draw (xps .arrays (xps .scalar_dtypes (), shape = shape ), label = "x" )
173
+ key = data .draw (xps .arrays (dtype = xp .bool , shape = shape ), label = "key" )
174
+ value = data .draw (
175
+ xps .from_dtype (x .dtype ) | xps .arrays (dtype = x .dtype , shape = ()), label = "value"
176
+ )
177
+
178
+ res = xp .asarray (x , copy = True )
179
+ res [key ] = value
180
+
181
+ ph .assert_dtype ("__setitem__" , x .dtype , res .dtype , repr_name = "x.dtype" )
182
+ ph .assert_shape ("__setitem__" , res .shape , x .shape , repr_name = "x.dtype" )
183
+ scalar_type = dh .get_scalar_type (x .dtype )
184
+ for idx in sh .ndindex (x .shape ):
185
+ if key [idx ]:
186
+ if isinstance (value , Scalar ):
187
+ ph .assert_scalar_equals (
188
+ "__setitem__" ,
189
+ scalar_type ,
190
+ idx ,
191
+ scalar_type (res [idx ]),
192
+ value ,
193
+ repr_name = "modified x" ,
194
+ )
195
+ else :
196
+ ph .assert_0d_equals (
197
+ "__setitem__" , "value" , value , f"modified x[{ idx } ]" , res [idx ]
198
+ )
199
+ else :
200
+ ph .assert_0d_equals (
201
+ "__setitem__" , f"old x[{ idx } ]" , x [idx ], f"modified x[{ idx } ]" , res [idx ]
202
+ )
111
203
112
204
113
205
def make_param (method_name : str , dtype : DataType , stype : ScalarType ) -> Param :
0 commit comments