Skip to content

Commit c9d958e

Browse files
committed
Refactor dtype_helpers
1 parent 8c5fad0 commit c9d958e

File tree

1 file changed

+45
-85
lines changed

1 file changed

+45
-85
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 45 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -2,74 +2,68 @@
22

33

44
__all__ = [
5-
"promotion_table",
6-
"dtype_nbits",
7-
"dtype_signed",
8-
"input_types",
9-
"dtypes_to_scalars",
10-
"elementwise_function_input_types",
11-
"elementwise_function_output_types",
12-
"binary_operators",
13-
"unary_operators",
14-
"operators_to_functions",
5+
'dtypes_to_scalars',
6+
'input_types',
7+
'promotion_table',
8+
'dtype_nbits',
9+
'dtype_signed',
10+
'binary_operators',
11+
'unary_operators',
12+
'operators_to_functions',
13+
'elementwise_function_input_types',
14+
'elementwise_function_output_types',
1515
]
1616

1717

18-
dtype_nbits = {
19-
**{d: 8 for d in [xp.int8, xp.uint8]},
20-
**{d: 16 for d in [xp.int16, xp.uint16]},
21-
**{d: 32 for d in [xp.int32, xp.uint32, xp.float32]},
22-
**{d: 64 for d in [xp.int64, xp.uint64, xp.float64]},
18+
int_dtypes = (xp.int8, xp.int16, xp.int32, xp.int64)
19+
uint_dtypes = (xp.uint8, xp.uint16, xp.uint32, xp.uint64)
20+
all_int_dtypes = int_dtypes + uint_dtypes
21+
float_dtypes = (xp.float32, xp.float64)
22+
numeric_dtypes = all_int_dtypes + float_dtypes
23+
all_dtypes = (xp.bool,) + numeric_dtypes
24+
25+
26+
dtypes_to_scalars = {
27+
xp.bool: [bool],
28+
**{d: [int] for d in all_int_dtypes},
29+
**{d: [int, float] for d in float_dtypes},
2330
}
2431

2532

26-
dtype_signed = {
27-
**{d: True for d in [xp.int8, xp.int16, xp.int32, xp.int64]},
28-
**{d: False for d in [xp.uint8, xp.uint16, xp.uint32, xp.uint64]},
33+
input_types = {
34+
'any': all_dtypes,
35+
'boolean': (xp.bool,),
36+
'floating': float_dtypes,
37+
'integer': all_int_dtypes,
38+
'integer_or_boolean': (xp.bool,) + uint_dtypes + int_dtypes,
39+
'numeric': numeric_dtypes,
2940
}
3041

3142

32-
signed_integer_promotion_table = {
43+
_numeric_promotions = {
44+
# ints
3345
(xp.int8, xp.int8): xp.int8,
3446
(xp.int8, xp.int16): xp.int16,
3547
(xp.int8, xp.int32): xp.int32,
3648
(xp.int8, xp.int64): xp.int64,
37-
(xp.int16, xp.int8): xp.int16,
3849
(xp.int16, xp.int16): xp.int16,
3950
(xp.int16, xp.int32): xp.int32,
4051
(xp.int16, xp.int64): xp.int64,
41-
(xp.int32, xp.int8): xp.int32,
42-
(xp.int32, xp.int16): xp.int32,
4352
(xp.int32, xp.int32): xp.int32,
4453
(xp.int32, xp.int64): xp.int64,
45-
(xp.int64, xp.int8): xp.int64,
46-
(xp.int64, xp.int16): xp.int64,
47-
(xp.int64, xp.int32): xp.int64,
4854
(xp.int64, xp.int64): xp.int64,
49-
}
50-
51-
52-
unsigned_integer_promotion_table = {
55+
# uints
5356
(xp.uint8, xp.uint8): xp.uint8,
5457
(xp.uint8, xp.uint16): xp.uint16,
5558
(xp.uint8, xp.uint32): xp.uint32,
5659
(xp.uint8, xp.uint64): xp.uint64,
57-
(xp.uint16, xp.uint8): xp.uint16,
5860
(xp.uint16, xp.uint16): xp.uint16,
5961
(xp.uint16, xp.uint32): xp.uint32,
6062
(xp.uint16, xp.uint64): xp.uint64,
61-
(xp.uint32, xp.uint8): xp.uint32,
62-
(xp.uint32, xp.uint16): xp.uint32,
6363
(xp.uint32, xp.uint32): xp.uint32,
6464
(xp.uint32, xp.uint64): xp.uint64,
65-
(xp.uint64, xp.uint8): xp.uint64,
66-
(xp.uint64, xp.uint16): xp.uint64,
67-
(xp.uint64, xp.uint32): xp.uint64,
6865
(xp.uint64, xp.uint64): xp.uint64,
69-
}
70-
71-
72-
mixed_signed_unsigned_promotion_table = {
66+
# ints and uints (mixed sign)
7367
(xp.int8, xp.uint8): xp.int16,
7468
(xp.int8, xp.uint16): xp.int32,
7569
(xp.int8, xp.uint32): xp.int64,
@@ -82,63 +76,29 @@
8276
(xp.int64, xp.uint8): xp.int64,
8377
(xp.int64, xp.uint16): xp.int64,
8478
(xp.int64, xp.uint32): xp.int64,
85-
}
86-
87-
88-
flipped_mixed_signed_unsigned_promotion_table = {(u, i): p for (i, u), p in mixed_signed_unsigned_promotion_table.items()}
89-
90-
91-
float_promotion_table = {
79+
# floats
9280
(xp.float32, xp.float32): xp.float32,
9381
(xp.float32, xp.float64): xp.float64,
94-
(xp.float64, xp.float32): xp.float64,
9582
(xp.float64, xp.float64): xp.float64,
9683
}
97-
98-
99-
boolean_promotion_table = {
100-
(xp.bool, xp.bool): xp.bool,
101-
}
102-
103-
10484
promotion_table = {
105-
**signed_integer_promotion_table,
106-
**unsigned_integer_promotion_table,
107-
**mixed_signed_unsigned_promotion_table,
108-
**flipped_mixed_signed_unsigned_promotion_table,
109-
**float_promotion_table,
110-
**boolean_promotion_table,
85+
(xp.bool, xp.bool): xp.bool,
86+
**_numeric_promotions,
87+
**{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()},
11188
}
11289

11390

114-
input_types = {
115-
'any': sorted(set(promotion_table.values())),
116-
'boolean': sorted(set(boolean_promotion_table.values())),
117-
'floating': sorted(set(float_promotion_table.values())),
118-
'integer': sorted(set({**signed_integer_promotion_table,
119-
**unsigned_integer_promotion_table}.values())),
120-
'integer_or_boolean': sorted(set({**signed_integer_promotion_table,
121-
**unsigned_integer_promotion_table,
122-
**boolean_promotion_table}.values())),
123-
'numeric': sorted(set({**float_promotion_table,
124-
**signed_integer_promotion_table,
125-
**unsigned_integer_promotion_table}.values())),
91+
dtype_nbits = {
92+
**{d: 8 for d in [xp.int8, xp.uint8]},
93+
**{d: 16 for d in [xp.int16, xp.uint16]},
94+
**{d: 32 for d in [xp.int32, xp.uint32, xp.float32]},
95+
**{d: 64 for d in [xp.int64, xp.uint64, xp.float64]},
12696
}
12797

12898

129-
dtypes_to_scalars = {
130-
xp.bool: [bool],
131-
xp.int8: [int],
132-
xp.int16: [int],
133-
xp.int32: [int],
134-
xp.int64: [int],
135-
# Note: unsigned int dtypes only correspond to positive integers
136-
xp.uint8: [int],
137-
xp.uint16: [int],
138-
xp.uint32: [int],
139-
xp.uint64: [int],
140-
xp.float32: [int, float],
141-
xp.float64: [int, float],
99+
dtype_signed = {
100+
**{d: True for d in int_dtypes},
101+
**{d: False for d in uint_dtypes},
142102
}
143103

144104

0 commit comments

Comments
 (0)