Skip to content

Commit d4467f1

Browse files
committed
Sort promotable_dtypes without relying on dtype comparisons
1 parent 75f44d7 commit d4467f1

File tree

1 file changed

+27
-15
lines changed

1 file changed

+27
-15
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,27 +47,39 @@
4747
shared_dtypes = shared(dtypes, key="dtype")
4848
shared_floating_dtypes = shared(floating_dtypes, key="dtype")
4949

50+
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes]
51+
_sorted_dtypes = [d for category in _dtype_categories for d in category]
52+
53+
def _dtypes_sorter(dtype_pair):
54+
dtype1, dtype2 = dtype_pair
55+
if dtype1 == dtype2:
56+
return _sorted_dtypes.index(dtype1)
57+
key = len(_sorted_dtypes)
58+
rank1 = _sorted_dtypes.index(dtype1)
59+
rank2 = _sorted_dtypes.index(dtype2)
60+
for category in _dtype_categories:
61+
if dtype1 in category and dtype2 in category:
62+
break
63+
else:
64+
key += len(_sorted_dtypes) ** 2
65+
key += 2 * (rank1 + rank2)
66+
if rank1 > rank2:
67+
key += 1
68+
return key
69+
70+
promotable_dtypes = sorted(dh.promotion_table.keys(), key=_dtypes_sorter)
5071

51-
sorted_table = sorted(dh.promotion_table)
52-
sorted_table = sorted(
53-
sorted_table, key=lambda ij: -1 if ij[0] == ij[1] else sorted_table.index(ij)
54-
)
5572
if FILTER_UNDEFINED_DTYPES:
56-
sorted_table = [(i, j) for i, j in sorted_table
57-
if not isinstance(i, _UndefinedStub)
58-
and not isinstance(j, _UndefinedStub)]
73+
promotable_dtypes = [
74+
(i, j) for i, j in promotable_dtypes
75+
if not isinstance(i, _UndefinedStub)
76+
and not isinstance(j, _UndefinedStub)
77+
]
5978

6079

6180
def mutually_promotable_dtypes(dtype_objs=dh.all_dtypes):
62-
# sort for shrinking (sampled_from shrinks to the earlier elements in the
63-
# list). Give pairs of the same dtypes first, then smaller dtypes,
64-
# preferring float, then int, then unsigned int. Note, this might not
65-
# always result in examples shrinking to these pairs because strategies
66-
# that draw from dtypes might not draw the same example from different
67-
# pairs (XXX: Can we redesign the strategies so that they can prefer
68-
# shrinking dtypes over values?)
6981
return sampled_from(
70-
[(i, j) for i, j in sorted_table if i in dtype_objs and j in dtype_objs]
82+
[(i, j) for i, j in promotable_dtypes if i in dtype_objs and j in dtype_objs]
7183
)
7284

7385
# shared() allows us to draw either the function or the function name and they

0 commit comments

Comments
 (0)