1
+ import math
2
+ from typing import Set
3
+
1
4
from hypothesis import given
2
5
from hypothesis import strategies as st
3
6
from hypothesis .control import assume
4
7
8
+ from xptests .typing import Scalar , ScalarType , Shape
9
+
5
10
from . import _array_module as xp
6
11
from . import dtype_helpers as dh
7
12
from . import hypothesis_helpers as hh
10
15
from . import xps
11
16
12
17
18
+ def assert_scalar_in_set (
19
+ func_name : str ,
20
+ type_ : ScalarType ,
21
+ idx : Shape ,
22
+ out : Scalar ,
23
+ set_ : Set [Scalar ],
24
+ / ,
25
+ ** kw ,
26
+ ):
27
+ out_repr = "out" if idx == () else f"out[{ idx } ]"
28
+ if math .isnan (out ):
29
+ raise NotImplementedError ()
30
+ msg = f"{ out_repr } ={ out } , but should be in { set_ } [{ func_name } ({ ph .fmt_kw (kw )} )]"
31
+ assert out in set_ , msg
32
+
33
+
13
34
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
14
35
@given (
15
36
x = xps .arrays (
@@ -34,20 +55,39 @@ def test_argsort(x, data):
34
55
35
56
out = xp .argsort (x , ** kw )
36
57
37
- ph .assert_default_index ("sort " , out .dtype )
38
- ph .assert_shape ("sort " , out .shape , x .shape , ** kw )
58
+ ph .assert_default_index ("argsort " , out .dtype )
59
+ ph .assert_shape ("argsort " , out .shape , x .shape , ** kw )
39
60
axis = kw .get ("axis" , - 1 )
40
61
axes = sh .normalise_axis (axis , x .ndim )
41
- descending = kw .get ("descending" , False )
42
62
scalar_type = dh .get_scalar_type (x .dtype )
43
63
for indices in sh .axes_ndindex (x .shape , axes ):
44
64
elements = [scalar_type (x [idx ]) for idx in indices ]
45
- indices_order = sorted (range (len (indices )), key = elements .__getitem__ )
46
- if descending :
47
- # sorted(..., reverse=descending) doesn't always work
48
- indices_order = reversed (indices_order )
49
- for idx , o in zip (indices , indices_order ):
50
- ph .assert_scalar_equals ("argsort" , int , idx , int (out [idx ]), o )
65
+ orders = sorted (range (len (elements )), key = elements .__getitem__ )
66
+ if kw .get ("descending" , False ):
67
+ orders = reversed (orders )
68
+ if kw .get ("stable" , True ):
69
+ for idx , o in zip (indices , orders ):
70
+ ph .assert_scalar_equals ("argsort" , int , idx , int (out [idx ]), o )
71
+ else :
72
+ idx_elements = dict (zip (indices , elements ))
73
+ idx_orders = dict (zip (indices , orders ))
74
+ element_orders = {}
75
+ for e in set (elements ):
76
+ element_orders [e ] = [
77
+ idx_orders [idx ] for idx in indices if idx_elements [idx ] == e
78
+ ]
79
+ for idx , e in zip (indices , elements ):
80
+ o = int (out [idx ])
81
+ expected_orders = element_orders [e ]
82
+ if len (expected_orders ) == 1 :
83
+ expected_order = expected_orders [0 ]
84
+ ph .assert_scalar_equals (
85
+ "argsort" , int , idx , o , expected_order , ** kw
86
+ )
87
+ else :
88
+ assert_scalar_in_set (
89
+ "argsort" , int , idx , o , set (expected_orders ), ** kw
90
+ )
51
91
52
92
53
93
# TODO: Test with signed zeros and NaNs (and ignore them somehow)
@@ -78,15 +118,15 @@ def test_sort(x, data):
78
118
ph .assert_shape ("sort" , out .shape , x .shape , ** kw )
79
119
axis = kw .get ("axis" , - 1 )
80
120
axes = sh .normalise_axis (axis , x .ndim )
81
- descending = kw .get ("descending" , False )
82
121
scalar_type = dh .get_scalar_type (x .dtype )
83
122
for indices in sh .axes_ndindex (x .shape , axes ):
84
123
elements = [scalar_type (x [idx ]) for idx in indices ]
85
- indices_order = sorted (
86
- range (len (indices )), key = elements .__getitem__ , reverse = descending
124
+ size = len (elements )
125
+ orders = sorted (
126
+ range (size ), key = elements .__getitem__ , reverse = kw .get ("descending" , False )
87
127
)
88
- x_indices = [ indices [ o ] for o in indices_order ]
89
- for out_idx , x_idx in zip ( indices , x_indices ):
128
+ for out_idx , o in zip ( indices , orders ):
129
+ x_idx = indices [ o ]
90
130
ph .assert_0d_equals (
91
131
"sort" ,
92
132
f"x[{ x_idx } ]" ,
0 commit comments