Skip to content

Commit 4bb70cf

Browse files
committed
Clean param generating logic
1 parent ca81c19 commit 4bb70cf

File tree

1 file changed

+80
-93
lines changed

1 file changed

+80
-93
lines changed

array_api_tests/test_elementwise_functions.py

Lines changed: 80 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,16 @@
3535
UnaryParam = Param[str, Callable[[Array], Array], st.SearchStrategy[Array]]
3636

3737

38-
def make_unary_params(func_name: str, dtypes: Sequence[DataType]) -> List[UnaryParam]:
38+
def make_unary_params(
39+
elwise_func_name: str, dtypes: Sequence[DataType]
40+
) -> List[UnaryParam]:
3941
strat = xps.arrays(dtype=st.sampled_from(dtypes), shape=hh.shapes())
40-
41-
func = getattr(xp, func_name)
42-
op = func_to_op[func_name]
43-
44-
def op_func(x: Array) -> Array:
45-
return getattr(x, op)()
46-
42+
func = getattr(xp, elwise_func_name)
43+
op_name = func_to_op[elwise_func_name]
44+
op = lambda x: getattr(x, op_name)()
4745
return [
48-
pytest.param(func_name, func, strat, id=func_name),
49-
pytest.param(op, op_func, strat, id=op),
46+
pytest.param(elwise_func_name, func, strat, id=elwise_func_name),
47+
pytest.param(op_name, op, strat, id=op_name),
5048
]
5149

5250

@@ -79,101 +77,90 @@ class FuncType(Enum):
7977
OP = auto()
8078
IOP = auto()
8179

82-
@classmethod
83-
def from_name(cls, name: str):
84-
if name in dh.binary_op_to_symbol.keys():
85-
return cls.OP
86-
elif name in dh.inplace_op_to_symbol.keys():
87-
return cls.IOP
88-
else:
89-
return cls.FUNC
90-
91-
92-
def _make_binary_param(
93-
func_name: str, right_is_scalar: bool, dtypes: Sequence[DataType]
94-
) -> BinaryParam:
95-
func_type = FuncType.from_name(func_name)
96-
97-
if right_is_scalar:
98-
left_sym = "x"
99-
right_sym = "s"
100-
else:
101-
left_sym = "x1"
102-
right_sym = "x2"
10380

81+
def make_binary_params(
82+
elwise_func_name: str, dtypes: Sequence[DataType]
83+
) -> List[BinaryParam]:
10484
dtypes_strat = st.sampled_from(dtypes)
105-
shared_dtypes = st.shared(dtypes_strat)
106-
if right_is_scalar:
107-
left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes())
108-
right_strat = shared_dtypes.flatmap(lambda d: xps.from_dtype(d, **finite_kw))
109-
else:
110-
if func_type is FuncType.IOP:
111-
shared_shapes = st.shared(hh.shapes())
112-
left_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes)
113-
right_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes)
114-
else:
115-
left_strat, right_strat = hh.two_mutual_arrays(dtypes)
116-
117-
if func_type is FuncType.FUNC:
118-
func = getattr(xp, func_name)
119-
else:
120-
op_sym = all_op_to_symbol[func_name]
121-
expr = f"{left_sym} {op_sym} {right_sym}"
122-
123-
if func_type is FuncType.OP:
124-
125-
def func(l: Array, r: Union[Scalar, Array]) -> Array:
126-
locals_ = {}
127-
locals_[left_sym] = l
128-
locals_[right_sym] = r
129-
return eval(expr, locals_)
13085

86+
def make_param(
87+
func_name: str, func_type: FuncType, right_is_scalar: bool
88+
) -> BinaryParam:
89+
if right_is_scalar:
90+
left_sym = "x"
91+
right_sym = "s"
13192
else:
93+
left_sym = "x1"
94+
right_sym = "x2"
95+
96+
shared_dtypes = st.shared(dtypes_strat)
97+
if right_is_scalar:
98+
left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes())
99+
right_strat = shared_dtypes.flatmap(
100+
lambda d: xps.from_dtype(d, **finite_kw)
101+
)
102+
else:
103+
if func_type is FuncType.IOP:
104+
shared_shapes = st.shared(hh.shapes())
105+
left_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes)
106+
right_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes)
107+
else:
108+
left_strat, right_strat = hh.two_mutual_arrays(dtypes)
109+
110+
if func_type is FuncType.FUNC:
111+
func = getattr(xp, func_name)
112+
else:
113+
op_sym = all_op_to_symbol[func_name]
114+
expr = f"{left_sym} {op_sym} {right_sym}"
115+
if func_type is FuncType.OP:
132116

133-
def func(l: Array, r: Union[Scalar, Array]) -> Array:
134-
locals_ = {}
135-
locals_[left_sym] = ah.asarray(l, copy=True) # prevents left mutating
136-
locals_[right_sym] = r
137-
exec(expr, locals_)
138-
return locals_[left_sym]
117+
def func(l: Array, r: Union[Scalar, Array]) -> Array:
118+
locals_ = {}
119+
locals_[left_sym] = l
120+
locals_[right_sym] = r
121+
return eval(expr, locals_)
139122

140-
func.__name__ = func_name # for repr
123+
else:
141124

142-
if func_type is FuncType.IOP:
143-
res_name = left_sym
144-
else:
145-
res_name = "out"
125+
def func(l: Array, r: Union[Scalar, Array]) -> Array:
126+
locals_ = {}
127+
locals_[left_sym] = ah.asarray(
128+
l, copy=True
129+
) # prevents left mutating
130+
locals_[right_sym] = r
131+
exec(expr, locals_)
132+
return locals_[left_sym]
146133

147-
f_id = func_name
148-
if right_is_scalar:
149-
f_id += "(x, s)"
150-
else:
151-
f_id += "(x1, x2)"
152-
153-
return pytest.param(
154-
func_name,
155-
func,
156-
left_sym,
157-
left_strat,
158-
right_sym,
159-
right_strat,
160-
right_is_scalar,
161-
res_name,
162-
id=f_id,
163-
)
134+
func.__name__ = func_name # for repr
164135

136+
if func_type is FuncType.IOP:
137+
res_name = left_sym
138+
else:
139+
res_name = "out"
140+
141+
return pytest.param(
142+
func_name,
143+
func,
144+
left_sym,
145+
left_strat,
146+
right_sym,
147+
right_strat,
148+
right_is_scalar,
149+
res_name,
150+
id=f"{func_name}({left_sym}, {right_sym})",
151+
)
165152

166-
def make_binary_params(func_name: str, dtypes: Sequence[DataType]) -> List[BinaryParam]:
167-
op = func_to_op[func_name]
153+
op_name = func_to_op[elwise_func_name]
168154
params = [
169-
_make_binary_param(func_name, False, dtypes),
170-
_make_binary_param(op, False, dtypes),
171-
_make_binary_param(op, True, dtypes),
155+
make_param(elwise_func_name, FuncType.FUNC, False),
156+
make_param(op_name, FuncType.OP, False),
157+
make_param(op_name, FuncType.OP, True),
172158
]
173-
iop = f"__i{op[2:]}"
174-
if iop in dh.inplace_op_to_symbol.keys():
175-
params.append(_make_binary_param(iop, False, dtypes))
176-
params.append(_make_binary_param(iop, True, dtypes))
159+
iop_name = f"__i{op_name[2:]}"
160+
if iop_name in dh.inplace_op_to_symbol.keys():
161+
params.append(make_param(iop_name, FuncType.IOP, False))
162+
params.append(make_param(iop_name, FuncType.IOP, True))
163+
177164
return params
178165

179166

0 commit comments

Comments
 (0)