Skip to content

Commit d4f8dea

Browse files
committed
Refactor things from test_type_promotion.py to dtype_helpers.py
1 parent 2e6f5a5 commit d4f8dea

File tree

8 files changed

+375
-356
lines changed

8 files changed

+375
-356
lines changed

array_api_tests/array_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
_numeric_dtypes, _boolean_dtypes, _dtypes,
1010
asarray)
1111
from . import _array_module
12+
from .dtype_helpers import dtype_mapping, promotion_table
1213

1314
# These are exported here so that they can be included in the special cases
1415
# tests from this file.
1516
from ._array_module import logical_not, subtract, floor, ceil, where
1617

18+
1719
__all__ = ['all', 'any', 'logical_and', 'logical_or', 'logical_not', 'less',
1820
'less_equal', 'greater', 'subtract', 'negative', 'floor', 'ceil',
1921
'where', 'isfinite', 'equal', 'not_equal', 'zero', 'one', 'NaN',
@@ -369,8 +371,6 @@ def promote_dtypes(dtype1, dtype2):
369371
Special case of result_type() which uses the exact type promotion table
370372
from the spec.
371373
"""
372-
from .test_type_promotion import dtype_mapping, promotion_table
373-
374374
# Equivalent to this, but some libraries may not work properly with using
375375
# dtype objects as dict keys
376376
#

array_api_tests/dtype_helpers.py

Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
from . import _array_module as xp
2+
3+
__all__ = [
4+
"dtype_mapping",
5+
"promotion_table",
6+
"dtype_nbits",
7+
"dtype_signed",
8+
"input_types",
9+
"dtypes_to_scalars",
10+
"elementwise_function_input_types",
11+
"elementwise_function_output_types",
12+
"binary_operators",
13+
"unary_operators",
14+
"operators_to_functions",
15+
]
16+
17+
dtype_mapping = {
18+
'i1': xp.int8,
19+
'i2': xp.int16,
20+
'i4': xp.int32,
21+
'i8': xp.int64,
22+
'u1': xp.uint8,
23+
'u2': xp.uint16,
24+
'u4': xp.uint32,
25+
'u8': xp.uint64,
26+
'f4': xp.float32,
27+
'f8': xp.float64,
28+
'b': xp.bool,
29+
}
30+
31+
reverse_dtype_mapping = {v: k for k, v in dtype_mapping.items()}
32+
33+
def dtype_nbits(dtype):
34+
if dtype == xp.int8:
35+
return 8
36+
elif dtype == xp.int16:
37+
return 16
38+
elif dtype == xp.int32:
39+
return 32
40+
elif dtype == xp.int64:
41+
return 64
42+
elif dtype == xp.uint8:
43+
return 8
44+
elif dtype == xp.uint16:
45+
return 16
46+
elif dtype == xp.uint32:
47+
return 32
48+
elif dtype == xp.uint64:
49+
return 64
50+
elif dtype == xp.float32:
51+
return 32
52+
elif dtype == xp.float64:
53+
return 64
54+
else:
55+
raise ValueError(f"dtype_nbits is not defined for {dtype}")
56+
57+
def dtype_signed(dtype):
58+
if dtype in [xp.int8, xp.int16, xp.int32, xp.int64]:
59+
return True
60+
elif dtype in [xp.uint8, xp.uint16, xp.uint32, xp.uint64]:
61+
return False
62+
raise ValueError("dtype_signed is only defined for integer dtypes")
63+
64+
signed_integer_promotion_table = {
65+
('i1', 'i1'): 'i1',
66+
('i1', 'i2'): 'i2',
67+
('i1', 'i4'): 'i4',
68+
('i1', 'i8'): 'i8',
69+
('i2', 'i1'): 'i2',
70+
('i2', 'i2'): 'i2',
71+
('i2', 'i4'): 'i4',
72+
('i2', 'i8'): 'i8',
73+
('i4', 'i1'): 'i4',
74+
('i4', 'i2'): 'i4',
75+
('i4', 'i4'): 'i4',
76+
('i4', 'i8'): 'i8',
77+
('i8', 'i1'): 'i8',
78+
('i8', 'i2'): 'i8',
79+
('i8', 'i4'): 'i8',
80+
('i8', 'i8'): 'i8',
81+
}
82+
83+
unsigned_integer_promotion_table = {
84+
('u1', 'u1'): 'u1',
85+
('u1', 'u2'): 'u2',
86+
('u1', 'u4'): 'u4',
87+
('u1', 'u8'): 'u8',
88+
('u2', 'u1'): 'u2',
89+
('u2', 'u2'): 'u2',
90+
('u2', 'u4'): 'u4',
91+
('u2', 'u8'): 'u8',
92+
('u4', 'u1'): 'u4',
93+
('u4', 'u2'): 'u4',
94+
('u4', 'u4'): 'u4',
95+
('u4', 'u8'): 'u8',
96+
('u8', 'u1'): 'u8',
97+
('u8', 'u2'): 'u8',
98+
('u8', 'u4'): 'u8',
99+
('u8', 'u8'): 'u8',
100+
}
101+
102+
mixed_signed_unsigned_promotion_table = {
103+
('i1', 'u1'): 'i2',
104+
('i1', 'u2'): 'i4',
105+
('i1', 'u4'): 'i8',
106+
('i2', 'u1'): 'i2',
107+
('i2', 'u2'): 'i4',
108+
('i2', 'u4'): 'i8',
109+
('i4', 'u1'): 'i4',
110+
('i4', 'u2'): 'i4',
111+
('i4', 'u4'): 'i8',
112+
('i8', 'u1'): 'i8',
113+
('i8', 'u2'): 'i8',
114+
('i8', 'u4'): 'i8',
115+
}
116+
117+
flipped_mixed_signed_unsigned_promotion_table = {(u, i): p for (i, u), p in mixed_signed_unsigned_promotion_table.items()}
118+
119+
float_promotion_table = {
120+
('f4', 'f4'): 'f4',
121+
('f4', 'f8'): 'f8',
122+
('f8', 'f4'): 'f8',
123+
('f8', 'f8'): 'f8',
124+
}
125+
126+
boolean_promotion_table = {
127+
('b', 'b'): 'b',
128+
}
129+
130+
promotion_table = {
131+
**signed_integer_promotion_table,
132+
**unsigned_integer_promotion_table,
133+
**mixed_signed_unsigned_promotion_table,
134+
**flipped_mixed_signed_unsigned_promotion_table,
135+
**float_promotion_table,
136+
**boolean_promotion_table,
137+
}
138+
139+
input_types = {
140+
'any': sorted(set(promotion_table.values())),
141+
'boolean': sorted(set(boolean_promotion_table.values())),
142+
'floating': sorted(set(float_promotion_table.values())),
143+
'integer': sorted(set({**signed_integer_promotion_table,
144+
**unsigned_integer_promotion_table}.values())),
145+
'integer_or_boolean': sorted(set({**signed_integer_promotion_table,
146+
**unsigned_integer_promotion_table,
147+
**boolean_promotion_table}.values())),
148+
'numeric': sorted(set({**float_promotion_table,
149+
**signed_integer_promotion_table,
150+
**unsigned_integer_promotion_table}.values())),
151+
}
152+
153+
dtypes_to_scalars = {
154+
'b': [bool],
155+
'i1': [int],
156+
'i2': [int],
157+
'i4': [int],
158+
'i8': [int],
159+
# Note: unsigned int dtypes only correspond to positive integers
160+
'u1': [int],
161+
'u2': [int],
162+
'u4': [int],
163+
'u8': [int],
164+
'f4': [int, float],
165+
'f8': [int, float],
166+
}
167+
168+
elementwise_function_input_types = {
169+
'abs': 'numeric',
170+
'acos': 'floating',
171+
'acosh': 'floating',
172+
'add': 'numeric',
173+
'asin': 'floating',
174+
'asinh': 'floating',
175+
'atan': 'floating',
176+
'atan2': 'floating',
177+
'atanh': 'floating',
178+
'bitwise_and': 'integer_or_boolean',
179+
'bitwise_invert': 'integer_or_boolean',
180+
'bitwise_left_shift': 'integer',
181+
'bitwise_or': 'integer_or_boolean',
182+
'bitwise_right_shift': 'integer',
183+
'bitwise_xor': 'integer_or_boolean',
184+
'ceil': 'numeric',
185+
'cos': 'floating',
186+
'cosh': 'floating',
187+
'divide': 'floating',
188+
'equal': 'any',
189+
'exp': 'floating',
190+
'expm1': 'floating',
191+
'floor': 'numeric',
192+
'floor_divide': 'numeric',
193+
'greater': 'numeric',
194+
'greater_equal': 'numeric',
195+
'isfinite': 'numeric',
196+
'isinf': 'numeric',
197+
'isnan': 'numeric',
198+
'less': 'numeric',
199+
'less_equal': 'numeric',
200+
'log': 'floating',
201+
'logaddexp': 'floating',
202+
'log10': 'floating',
203+
'log1p': 'floating',
204+
'log2': 'floating',
205+
'logical_and': 'boolean',
206+
'logical_not': 'boolean',
207+
'logical_or': 'boolean',
208+
'logical_xor': 'boolean',
209+
'multiply': 'numeric',
210+
'negative': 'numeric',
211+
'not_equal': 'any',
212+
'positive': 'numeric',
213+
'pow': 'floating',
214+
'remainder': 'numeric',
215+
'round': 'numeric',
216+
'sign': 'numeric',
217+
'sin': 'floating',
218+
'sinh': 'floating',
219+
'sqrt': 'floating',
220+
'square': 'numeric',
221+
'subtract': 'numeric',
222+
'tan': 'floating',
223+
'tanh': 'floating',
224+
'trunc': 'numeric',
225+
}
226+
227+
elementwise_function_output_types = {
228+
'abs': 'promoted',
229+
'acos': 'promoted',
230+
'acosh': 'promoted',
231+
'add': 'promoted',
232+
'asin': 'promoted',
233+
'asinh': 'promoted',
234+
'atan': 'promoted',
235+
'atan2': 'promoted',
236+
'atanh': 'promoted',
237+
'bitwise_and': 'promoted',
238+
'bitwise_invert': 'promoted',
239+
'bitwise_left_shift': 'promoted',
240+
'bitwise_or': 'promoted',
241+
'bitwise_right_shift': 'promoted',
242+
'bitwise_xor': 'promoted',
243+
'ceil': 'promoted',
244+
'cos': 'promoted',
245+
'cosh': 'promoted',
246+
'divide': 'promoted',
247+
'equal': 'bool',
248+
'exp': 'promoted',
249+
'expm1': 'promoted',
250+
'floor': 'promoted',
251+
'floor_divide': 'promoted',
252+
'greater': 'bool',
253+
'greater_equal': 'bool',
254+
'isfinite': 'bool',
255+
'isinf': 'bool',
256+
'isnan': 'bool',
257+
'less': 'bool',
258+
'less_equal': 'bool',
259+
'log': 'promoted',
260+
'logaddexp': 'promoted',
261+
'log10': 'promoted',
262+
'log1p': 'promoted',
263+
'log2': 'promoted',
264+
'logical_and': 'bool',
265+
'logical_not': 'bool',
266+
'logical_or': 'bool',
267+
'logical_xor': 'bool',
268+
'multiply': 'promoted',
269+
'negative': 'promoted',
270+
'not_equal': 'bool',
271+
'positive': 'promoted',
272+
'pow': 'promoted',
273+
'remainder': 'promoted',
274+
'round': 'promoted',
275+
'sign': 'promoted',
276+
'sin': 'promoted',
277+
'sinh': 'promoted',
278+
'sqrt': 'promoted',
279+
'square': 'promoted',
280+
'subtract': 'promoted',
281+
'tan': 'promoted',
282+
'tanh': 'promoted',
283+
'trunc': 'promoted',
284+
}
285+
286+
binary_operators = {
287+
'__add__': '+',
288+
'__and__': '&',
289+
'__eq__': '==',
290+
'__floordiv__': '//',
291+
'__ge__': '>=',
292+
'__gt__': '>',
293+
'__le__': '<=',
294+
'__lshift__': '<<',
295+
'__lt__': '<',
296+
'__matmul__': '@',
297+
'__mod__': '%',
298+
'__mul__': '*',
299+
'__ne__': '!=',
300+
'__or__': '|',
301+
'__pow__': '**',
302+
'__rshift__': '>>',
303+
'__sub__': '-',
304+
'__truediv__': '/',
305+
'__xor__': '^',
306+
}
307+
308+
unary_operators = {
309+
'__abs__': 'abs()',
310+
'__invert__': '~',
311+
'__neg__': '-',
312+
'__pos__': '+',
313+
}
314+
315+
316+
operators_to_functions = {
317+
'__abs__': 'abs',
318+
'__add__': 'add',
319+
'__and__': 'bitwise_and',
320+
'__eq__': 'equal',
321+
'__floordiv__': 'floor_divide',
322+
'__ge__': 'greater_equal',
323+
'__gt__': 'greater',
324+
'__le__': 'less_equal',
325+
'__lshift__': 'bitwise_left_shift',
326+
'__lt__': 'less',
327+
'__matmul__': 'matmul',
328+
'__mod__': 'remainder',
329+
'__mul__': 'multiply',
330+
'__ne__': 'not_equal',
331+
'__or__': 'bitwise_or',
332+
'__pow__': 'pow',
333+
'__rshift__': 'bitwise_right_shift',
334+
'__sub__': 'subtract',
335+
'__truediv__': 'divide',
336+
'__xor__': 'bitwise_xor',
337+
'__invert__': 'bitwise_invert',
338+
'__neg__': 'negative',
339+
'__pos__': 'positive',
340+
}

0 commit comments

Comments
 (0)