Skip to content

Commit ca81c19

Browse files
committed
Use enum for func types
1 parent 84a45d2 commit ca81c19

File tree

1 file changed

+27
-12
lines changed

1 file changed

+27
-12
lines changed

array_api_tests/test_elementwise_functions.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"""
1111

1212
import math
13+
from enum import Enum, auto
1314
from typing import Callable, List, Sequence, Union
1415

1516
import pytest
@@ -25,7 +26,6 @@
2526
from . import xps
2627
from .typing import Array, DataType, Param, Scalar
2728

28-
2929
# We might as well use this implementation rather than xp.broadcast_shapes()
3030
from .test_broadcasting import broadcast_shapes
3131

@@ -74,38 +74,53 @@ def op_func(x: Array) -> Array:
7474
]
7575

7676

77+
class FuncType(Enum):
78+
FUNC = auto()
79+
OP = auto()
80+
IOP = auto()
81+
82+
@classmethod
83+
def from_name(cls, name: str):
84+
if name in dh.binary_op_to_symbol.keys():
85+
return cls.OP
86+
elif name in dh.inplace_op_to_symbol.keys():
87+
return cls.IOP
88+
else:
89+
return cls.FUNC
90+
91+
7792
def _make_binary_param(
7893
func_name: str, right_is_scalar: bool, dtypes: Sequence[DataType]
7994
) -> BinaryParam:
80-
if func_name in dh.binary_op_to_symbol.keys():
81-
func_type = "op"
82-
elif func_name in dh.inplace_op_to_symbol.keys():
83-
func_type = "iop"
84-
else:
85-
func_type = "func"
95+
func_type = FuncType.from_name(func_name)
8696

87-
left_sym, right_sym = ("x", "s") if right_is_scalar else ("x1", "x2")
97+
if right_is_scalar:
98+
left_sym = "x"
99+
right_sym = "s"
100+
else:
101+
left_sym = "x1"
102+
right_sym = "x2"
88103

89104
dtypes_strat = st.sampled_from(dtypes)
90105
shared_dtypes = st.shared(dtypes_strat)
91106
if right_is_scalar:
92107
left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes())
93108
right_strat = shared_dtypes.flatmap(lambda d: xps.from_dtype(d, **finite_kw))
94109
else:
95-
if func_type == "iop":
110+
if func_type is FuncType.IOP:
96111
shared_shapes = st.shared(hh.shapes())
97112
left_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes)
98113
right_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes)
99114
else:
100115
left_strat, right_strat = hh.two_mutual_arrays(dtypes)
101116

102-
if func_type == "func":
117+
if func_type is FuncType.FUNC:
103118
func = getattr(xp, func_name)
104119
else:
105120
op_sym = all_op_to_symbol[func_name]
106121
expr = f"{left_sym} {op_sym} {right_sym}"
107122

108-
if func_type == "op":
123+
if func_type is FuncType.OP:
109124

110125
def func(l: Array, r: Union[Scalar, Array]) -> Array:
111126
locals_ = {}
@@ -124,7 +139,7 @@ def func(l: Array, r: Union[Scalar, Array]) -> Array:
124139

125140
func.__name__ = func_name # for repr
126141

127-
if func_type == "iop":
142+
if func_type is FuncType.IOP:
128143
res_name = left_sym
129144
else:
130145
res_name = "out"

0 commit comments

Comments
 (0)