@@ -46,29 +46,35 @@ def test_getitem(shape, data):
46
46
47
47
out = x [key ]
48
48
49
+ ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
50
+
49
51
_key = tuple (key ) if isinstance (key , tuple ) else (key ,)
50
52
if Ellipsis in _key :
51
53
start_a = _key .index (Ellipsis )
52
54
stop_a = start_a + (len (shape ) - (len (_key ) - 1 ))
53
55
slices = tuple (slice (None , None ) for _ in range (start_a , stop_a ))
54
56
_key = _key [:start_a ] + slices + _key [start_a + 1 :]
55
57
axes_indices = []
58
+ out_shape = []
56
59
for a , i in enumerate (_key ):
57
60
if isinstance (i , int ):
58
61
axes_indices .append ([i ])
59
62
else :
60
63
side = shape [a ]
61
64
indices = range (side )[i ]
62
- assume (len (indices ) > 0 ) # TODO: test 0-sided arrays
63
65
axes_indices .append (indices )
64
- expected = []
66
+ out_shape .append (len (indices ))
67
+ out_shape = tuple (out_shape )
68
+ ph .assert_shape ("__getitem__" , out .shape , out_shape )
69
+ assume (all (len (indices ) > 0 for indices in axes_indices ))
70
+ out_obj = []
65
71
for idx in product (* axes_indices ):
66
72
val = obj
67
73
for i in idx :
68
74
val = val [i ]
69
- expected .append (val )
70
- expected = reshape (expected , out . shape )
71
- expected = xp .asarray (expected , dtype = dtype )
75
+ out_obj .append (val )
76
+ out_obj = reshape (out_obj , out_shape )
77
+ expected = xp .asarray (out_obj , dtype = dtype )
72
78
ph .assert_array ("__getitem__" , out , expected )
73
79
74
80
0 commit comments