|
47 | 47 | shared_dtypes = shared(dtypes, key="dtype")
|
48 | 48 | shared_floating_dtypes = shared(floating_dtypes, key="dtype")
|
49 | 49 |
|
| 50 | +_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes] |
| 51 | +_sorted_dtypes = [d for category in _dtype_categories for d in category] |
| 52 | + |
| 53 | +def _dtypes_sorter(dtype_pair): |
| 54 | + dtype1, dtype2 = dtype_pair |
| 55 | + if dtype1 == dtype2: |
| 56 | + return _sorted_dtypes.index(dtype1) |
| 57 | + key = len(_sorted_dtypes) |
| 58 | + rank1 = _sorted_dtypes.index(dtype1) |
| 59 | + rank2 = _sorted_dtypes.index(dtype2) |
| 60 | + for category in _dtype_categories: |
| 61 | + if dtype1 in category and dtype2 in category: |
| 62 | + break |
| 63 | + else: |
| 64 | + key += len(_sorted_dtypes) ** 2 |
| 65 | + key += 2 * (rank1 + rank2) |
| 66 | + if rank1 > rank2: |
| 67 | + key += 1 |
| 68 | + return key |
| 69 | + |
| 70 | +promotable_dtypes = sorted(dh.promotion_table.keys(), key=_dtypes_sorter) |
50 | 71 |
|
51 |
| -sorted_table = sorted(dh.promotion_table) |
52 |
| -sorted_table = sorted( |
53 |
| - sorted_table, key=lambda ij: -1 if ij[0] == ij[1] else sorted_table.index(ij) |
54 |
| -) |
55 | 72 | if FILTER_UNDEFINED_DTYPES:
|
56 |
| - sorted_table = [(i, j) for i, j in sorted_table |
57 |
| - if not isinstance(i, _UndefinedStub) |
58 |
| - and not isinstance(j, _UndefinedStub)] |
| 73 | + promotable_dtypes = [ |
| 74 | + (i, j) for i, j in promotable_dtypes |
| 75 | + if not isinstance(i, _UndefinedStub) |
| 76 | + and not isinstance(j, _UndefinedStub) |
| 77 | + ] |
59 | 78 |
|
60 | 79 |
|
61 | 80 | def mutually_promotable_dtypes(dtype_objs=dh.all_dtypes):
|
62 |
| - # sort for shrinking (sampled_from shrinks to the earlier elements in the |
63 |
| - # list). Give pairs of the same dtypes first, then smaller dtypes, |
64 |
| - # preferring float, then int, then unsigned int. Note, this might not |
65 |
| - # always result in examples shrinking to these pairs because strategies |
66 |
| - # that draw from dtypes might not draw the same example from different |
67 |
| - # pairs (XXX: Can we redesign the strategies so that they can prefer |
68 |
| - # shrinking dtypes over values?) |
69 | 81 | return sampled_from(
|
70 |
| - [(i, j) for i, j in sorted_table if i in dtype_objs and j in dtype_objs] |
| 82 | + [(i, j) for i, j in promotable_dtypes if i in dtype_objs and j in dtype_objs] |
71 | 83 | )
|
72 | 84 |
|
73 | 85 | # shared() allows us to draw either the function or the function name and they
|
|
0 commit comments