Skip to content

Commit c07ba0e

Browse files
committed
Rudimentary elementwise parameters refactor
1 parent c9d958e commit c07ba0e

File tree

4 files changed

+120
-169
lines changed

4 files changed

+120
-169
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33

44
__all__ = [
55
'dtypes_to_scalars',
6-
'input_types',
6+
'category_to_dtypes',
77
'promotion_table',
88
'dtype_nbits',
99
'dtype_signed',
10+
'func_in_categories',
11+
'func_out_categories',
1012
'binary_operators',
1113
'unary_operators',
1214
'operators_to_functions',
13-
'elementwise_function_input_types',
14-
'elementwise_function_output_types',
1515
]
1616

1717

@@ -30,7 +30,7 @@
3030
}
3131

3232

33-
input_types = {
33+
category_to_dtypes = {
3434
'any': all_dtypes,
3535
'boolean': (xp.bool,),
3636
'floating': float_dtypes,
@@ -84,6 +84,7 @@
8484
promotion_table = {
8585
(xp.bool, xp.bool): xp.bool,
8686
**_numeric_promotions,
87+
# TODO: dont unpack pairs of the same dtype
8788
**{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()},
8889
}
8990

@@ -102,7 +103,7 @@
102103
}
103104

104105

105-
elementwise_function_input_types = {
106+
func_in_categories = {
106107
'abs': 'numeric',
107108
'acos': 'floating',
108109
'acosh': 'floating',
@@ -162,7 +163,7 @@
162163
}
163164

164165

165-
elementwise_function_output_types = {
166+
func_out_categories = {
166167
'abs': 'promoted',
167168
'acos': 'promoted',
168169
'acosh': 'promoted',

array_api_tests/test_broadcasting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .hypothesis_helpers import shapes, FILTER_UNDEFINED_DTYPES
1111
from .pytest_helpers import raises, doesnt_raise, nargs
1212

13-
from .dtype_helpers import elementwise_function_input_types, input_types
13+
from .dtype_helpers import func_in_categories, category_to_dtypes
1414
from .function_stubs import elementwise_functions
1515
from . import _array_module
1616
from ._array_module import ones, _UndefinedStub
@@ -115,7 +115,7 @@ def test_broadcasting_hypothesis(func_name, shape1, shape2, data):
115115
# Internal consistency checks
116116
assert nargs(func_name) == 2
117117

118-
dtype = data.draw(sampled_from(input_types[elementwise_function_input_types[func_name]]))
118+
dtype = data.draw(sampled_from(category_to_dtypes[func_in_categories[func_name]]))
119119
if FILTER_UNDEFINED_DTYPES:
120120
assume(not isinstance(dtype, _UndefinedStub))
121121
func = getattr(_array_module, func_name)

array_api_tests/test_signatures.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from ._array_module import mod, mod_name, ones, eye, float64, bool, int64
66
from .pytest_helpers import raises, doesnt_raise
7-
from .dtype_helpers import elementwise_function_input_types, operators_to_functions
7+
from .dtype_helpers import func_in_categories, operators_to_functions
88

99
from . import function_stubs
1010

@@ -163,9 +163,9 @@ def test_function_positional_args(name):
163163
n = operators_to_functions[name[:2] + name[3:]]
164164
else:
165165
n = operators_to_functions.get(name, name)
166-
if 'boolean' in elementwise_function_input_types.get(n, 'floating'):
166+
if 'boolean' in func_in_categories.get(n, 'floating'):
167167
dtype = bool
168-
elif 'integer' in elementwise_function_input_types.get(n, 'floating'):
168+
elif 'integer' in func_in_categories.get(n, 'floating'):
169169
dtype = int64
170170

171171
if array_method(name):

0 commit comments

Comments
 (0)