|
2 | 2 |
|
3 | 3 |
|
4 | 4 | __all__ = [
|
5 |
| - "promotion_table", |
6 |
| - "dtype_nbits", |
7 |
| - "dtype_signed", |
8 |
| - "input_types", |
9 |
| - "dtypes_to_scalars", |
10 |
| - "elementwise_function_input_types", |
11 |
| - "elementwise_function_output_types", |
12 |
| - "binary_operators", |
13 |
| - "unary_operators", |
14 |
| - "operators_to_functions", |
| 5 | + 'dtypes_to_scalars', |
| 6 | + 'input_types', |
| 7 | + 'promotion_table', |
| 8 | + 'dtype_nbits', |
| 9 | + 'dtype_signed', |
| 10 | + 'binary_operators', |
| 11 | + 'unary_operators', |
| 12 | + 'operators_to_functions', |
| 13 | + 'elementwise_function_input_types', |
| 14 | + 'elementwise_function_output_types', |
15 | 15 | ]
|
16 | 16 |
|
17 | 17 |
|
18 |
| -dtype_nbits = { |
19 |
| - **{d: 8 for d in [xp.int8, xp.uint8]}, |
20 |
| - **{d: 16 for d in [xp.int16, xp.uint16]}, |
21 |
| - **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, |
22 |
| - **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, |
| 18 | +int_dtypes = (xp.int8, xp.int16, xp.int32, xp.int64) |
| 19 | +uint_dtypes = (xp.uint8, xp.uint16, xp.uint32, xp.uint64) |
| 20 | +all_int_dtypes = int_dtypes + uint_dtypes |
| 21 | +float_dtypes = (xp.float32, xp.float64) |
| 22 | +numeric_dtypes = all_int_dtypes + float_dtypes |
| 23 | +all_dtypes = (xp.bool,) + numeric_dtypes |
| 24 | + |
| 25 | + |
| 26 | +dtypes_to_scalars = { |
| 27 | + xp.bool: [bool], |
| 28 | + **{d: [int] for d in all_int_dtypes}, |
| 29 | + **{d: [int, float] for d in float_dtypes}, |
23 | 30 | }
|
24 | 31 |
|
25 | 32 |
|
26 |
| -dtype_signed = { |
27 |
| - **{d: True for d in [xp.int8, xp.int16, xp.int32, xp.int64]}, |
28 |
| - **{d: False for d in [xp.uint8, xp.uint16, xp.uint32, xp.uint64]}, |
| 33 | +input_types = { |
| 34 | + 'any': all_dtypes, |
| 35 | + 'boolean': (xp.bool,), |
| 36 | + 'floating': float_dtypes, |
| 37 | + 'integer': all_int_dtypes, |
| 38 | + 'integer_or_boolean': (xp.bool,) + uint_dtypes + int_dtypes, |
| 39 | + 'numeric': numeric_dtypes, |
29 | 40 | }
|
30 | 41 |
|
31 | 42 |
|
32 |
| -signed_integer_promotion_table = { |
| 43 | +_numeric_promotions = { |
| 44 | + # ints |
33 | 45 | (xp.int8, xp.int8): xp.int8,
|
34 | 46 | (xp.int8, xp.int16): xp.int16,
|
35 | 47 | (xp.int8, xp.int32): xp.int32,
|
36 | 48 | (xp.int8, xp.int64): xp.int64,
|
37 |
| - (xp.int16, xp.int8): xp.int16, |
38 | 49 | (xp.int16, xp.int16): xp.int16,
|
39 | 50 | (xp.int16, xp.int32): xp.int32,
|
40 | 51 | (xp.int16, xp.int64): xp.int64,
|
41 |
| - (xp.int32, xp.int8): xp.int32, |
42 |
| - (xp.int32, xp.int16): xp.int32, |
43 | 52 | (xp.int32, xp.int32): xp.int32,
|
44 | 53 | (xp.int32, xp.int64): xp.int64,
|
45 |
| - (xp.int64, xp.int8): xp.int64, |
46 |
| - (xp.int64, xp.int16): xp.int64, |
47 |
| - (xp.int64, xp.int32): xp.int64, |
48 | 54 | (xp.int64, xp.int64): xp.int64,
|
49 |
| -} |
50 |
| - |
51 |
| - |
52 |
| -unsigned_integer_promotion_table = { |
| 55 | + # uints |
53 | 56 | (xp.uint8, xp.uint8): xp.uint8,
|
54 | 57 | (xp.uint8, xp.uint16): xp.uint16,
|
55 | 58 | (xp.uint8, xp.uint32): xp.uint32,
|
56 | 59 | (xp.uint8, xp.uint64): xp.uint64,
|
57 |
| - (xp.uint16, xp.uint8): xp.uint16, |
58 | 60 | (xp.uint16, xp.uint16): xp.uint16,
|
59 | 61 | (xp.uint16, xp.uint32): xp.uint32,
|
60 | 62 | (xp.uint16, xp.uint64): xp.uint64,
|
61 |
| - (xp.uint32, xp.uint8): xp.uint32, |
62 |
| - (xp.uint32, xp.uint16): xp.uint32, |
63 | 63 | (xp.uint32, xp.uint32): xp.uint32,
|
64 | 64 | (xp.uint32, xp.uint64): xp.uint64,
|
65 |
| - (xp.uint64, xp.uint8): xp.uint64, |
66 |
| - (xp.uint64, xp.uint16): xp.uint64, |
67 |
| - (xp.uint64, xp.uint32): xp.uint64, |
68 | 65 | (xp.uint64, xp.uint64): xp.uint64,
|
69 |
| -} |
70 |
| - |
71 |
| - |
72 |
| -mixed_signed_unsigned_promotion_table = { |
| 66 | + # ints and uints (mixed sign) |
73 | 67 | (xp.int8, xp.uint8): xp.int16,
|
74 | 68 | (xp.int8, xp.uint16): xp.int32,
|
75 | 69 | (xp.int8, xp.uint32): xp.int64,
|
|
82 | 76 | (xp.int64, xp.uint8): xp.int64,
|
83 | 77 | (xp.int64, xp.uint16): xp.int64,
|
84 | 78 | (xp.int64, xp.uint32): xp.int64,
|
85 |
| -} |
86 |
| - |
87 |
| - |
88 |
| -flipped_mixed_signed_unsigned_promotion_table = {(u, i): p for (i, u), p in mixed_signed_unsigned_promotion_table.items()} |
89 |
| - |
90 |
| - |
91 |
| -float_promotion_table = { |
| 79 | + # floats |
92 | 80 | (xp.float32, xp.float32): xp.float32,
|
93 | 81 | (xp.float32, xp.float64): xp.float64,
|
94 |
| - (xp.float64, xp.float32): xp.float64, |
95 | 82 | (xp.float64, xp.float64): xp.float64,
|
96 | 83 | }
|
97 |
| - |
98 |
| - |
99 |
| -boolean_promotion_table = { |
100 |
| - (xp.bool, xp.bool): xp.bool, |
101 |
| -} |
102 |
| - |
103 |
| - |
104 | 84 | promotion_table = {
|
105 |
| - **signed_integer_promotion_table, |
106 |
| - **unsigned_integer_promotion_table, |
107 |
| - **mixed_signed_unsigned_promotion_table, |
108 |
| - **flipped_mixed_signed_unsigned_promotion_table, |
109 |
| - **float_promotion_table, |
110 |
| - **boolean_promotion_table, |
| 85 | + (xp.bool, xp.bool): xp.bool, |
| 86 | + **_numeric_promotions, |
| 87 | + **{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()}, |
111 | 88 | }
|
112 | 89 |
|
113 | 90 |
|
114 |
| -input_types = { |
115 |
| - 'any': sorted(set(promotion_table.values())), |
116 |
| - 'boolean': sorted(set(boolean_promotion_table.values())), |
117 |
| - 'floating': sorted(set(float_promotion_table.values())), |
118 |
| - 'integer': sorted(set({**signed_integer_promotion_table, |
119 |
| - **unsigned_integer_promotion_table}.values())), |
120 |
| - 'integer_or_boolean': sorted(set({**signed_integer_promotion_table, |
121 |
| - **unsigned_integer_promotion_table, |
122 |
| - **boolean_promotion_table}.values())), |
123 |
| - 'numeric': sorted(set({**float_promotion_table, |
124 |
| - **signed_integer_promotion_table, |
125 |
| - **unsigned_integer_promotion_table}.values())), |
| 91 | +dtype_nbits = { |
| 92 | + **{d: 8 for d in [xp.int8, xp.uint8]}, |
| 93 | + **{d: 16 for d in [xp.int16, xp.uint16]}, |
| 94 | + **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, |
| 95 | + **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, |
126 | 96 | }
|
127 | 97 |
|
128 | 98 |
|
129 |
| -dtypes_to_scalars = { |
130 |
| - xp.bool: [bool], |
131 |
| - xp.int8: [int], |
132 |
| - xp.int16: [int], |
133 |
| - xp.int32: [int], |
134 |
| - xp.int64: [int], |
135 |
| - # Note: unsigned int dtypes only correspond to positive integers |
136 |
| - xp.uint8: [int], |
137 |
| - xp.uint16: [int], |
138 |
| - xp.uint32: [int], |
139 |
| - xp.uint64: [int], |
140 |
| - xp.float32: [int, float], |
141 |
| - xp.float64: [int, float], |
| 99 | +dtype_signed = { |
| 100 | + **{d: True for d in int_dtypes}, |
| 101 | + **{d: False for d in uint_dtypes}, |
142 | 102 | }
|
143 | 103 |
|
144 | 104 |
|
|
0 commit comments