1
1
import math
2
2
from itertools import product
3
- from typing import Sequence , Union
3
+ from typing import Sequence , Union , get_args
4
4
5
5
import pytest
6
6
from hypothesis import assume , given , note
@@ -33,11 +33,9 @@ def test_getitem(shape, data):
33
33
size = math .prod (shape )
34
34
dtype = data .draw (xps .scalar_dtypes (), label = "dtype" )
35
35
obj = data .draw (
36
- st .lists (
37
- xps .from_dtype (dtype ),
38
- min_size = size ,
39
- max_size = size ,
40
- ).map (lambda l : reshape (l , shape )),
36
+ st .lists (xps .from_dtype (dtype ), min_size = size , max_size = size ).map (
37
+ lambda l : reshape (l , shape )
38
+ ),
41
39
label = "obj" ,
42
40
)
43
41
x = xp .asarray (obj , dtype = dtype )
@@ -47,7 +45,6 @@ def test_getitem(shape, data):
47
45
out = x [key ]
48
46
49
47
ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
50
-
51
48
_key = tuple (key ) if isinstance (key , tuple ) else (key ,)
52
49
if Ellipsis in _key :
53
50
start_a = _key .index (Ellipsis )
@@ -78,7 +75,39 @@ def test_getitem(shape, data):
78
75
ph .assert_array ("__getitem__" , out , expected )
79
76
80
77
81
- # TODO: test_setitem
78
+ @given (hh .shapes (min_side = 1 ), st .data ()) # TODO: test 0-sided arrays
79
+ def test_setitem (shape , data ):
80
+ size = math .prod (shape )
81
+ dtype = data .draw (xps .scalar_dtypes (), label = "dtype" )
82
+ obj = data .draw (
83
+ st .lists (xps .from_dtype (dtype ), min_size = size , max_size = size ).map (
84
+ lambda l : reshape (l , shape )
85
+ ),
86
+ label = "obj" ,
87
+ )
88
+ x = xp .asarray (obj , dtype = dtype )
89
+ note (f"{ x = } " )
90
+ key = data .draw (xps .indices (shape = shape , max_dims = 0 ), label = "key" )
91
+ value = data .draw (
92
+ xps .from_dtype (dtype ) | xps .arrays (dtype = dtype , shape = ()), label = "value"
93
+ )
94
+
95
+ res = xp .asarray (x , copy = True )
96
+ res [key ] = value
97
+
98
+ ph .assert_dtype ("__setitem__" , x .dtype , res .dtype , repr_name = "x.dtype" )
99
+ ph .assert_shape ("__setitem__" , res .shape , x .shape , repr_name = "x.shape" )
100
+ if isinstance (value , get_args (Scalar )):
101
+ msg = f"x[{ key } ]={ res [key ]!r} , but should be { value = } [__setitem__()]"
102
+ if math .isnan (value ):
103
+ assert xp .isnan (res [key ]), msg
104
+ else :
105
+ assert res [key ] == value , msg
106
+ else :
107
+ ph .assert_0d_equals ("__setitem__" , "value" , value , f"x[{ key } ]" , res [key ])
108
+
109
+
110
+ # TODO: test boolean indexing
82
111
83
112
84
113
def make_param (method_name : str , dtype : DataType , stype : ScalarType ) -> Param :
0 commit comments