|
35 | 35 | UnaryParam = Param[str, Callable[[Array], Array], st.SearchStrategy[Array]]
|
36 | 36 |
|
37 | 37 |
|
38 |
| -def make_unary_params(func_name: str, dtypes: Sequence[DataType]) -> List[UnaryParam]: |
| 38 | +def make_unary_params( |
| 39 | + elwise_func_name: str, dtypes: Sequence[DataType] |
| 40 | +) -> List[UnaryParam]: |
39 | 41 | strat = xps.arrays(dtype=st.sampled_from(dtypes), shape=hh.shapes())
|
40 |
| - |
41 |
| - func = getattr(xp, func_name) |
42 |
| - op = func_to_op[func_name] |
43 |
| - |
44 |
| - def op_func(x: Array) -> Array: |
45 |
| - return getattr(x, op)() |
46 |
| - |
| 42 | + func = getattr(xp, elwise_func_name) |
| 43 | + op_name = func_to_op[elwise_func_name] |
| 44 | + op = lambda x: getattr(x, op_name)() |
47 | 45 | return [
|
48 |
| - pytest.param(func_name, func, strat, id=func_name), |
49 |
| - pytest.param(op, op_func, strat, id=op), |
| 46 | + pytest.param(elwise_func_name, func, strat, id=elwise_func_name), |
| 47 | + pytest.param(op_name, op, strat, id=op_name), |
50 | 48 | ]
|
51 | 49 |
|
52 | 50 |
|
@@ -79,101 +77,90 @@ class FuncType(Enum):
|
79 | 77 | OP = auto()
|
80 | 78 | IOP = auto()
|
81 | 79 |
|
82 |
| - @classmethod |
83 |
| - def from_name(cls, name: str): |
84 |
| - if name in dh.binary_op_to_symbol.keys(): |
85 |
| - return cls.OP |
86 |
| - elif name in dh.inplace_op_to_symbol.keys(): |
87 |
| - return cls.IOP |
88 |
| - else: |
89 |
| - return cls.FUNC |
90 |
| - |
91 |
| - |
92 |
| -def _make_binary_param( |
93 |
| - func_name: str, right_is_scalar: bool, dtypes: Sequence[DataType] |
94 |
| -) -> BinaryParam: |
95 |
| - func_type = FuncType.from_name(func_name) |
96 |
| - |
97 |
| - if right_is_scalar: |
98 |
| - left_sym = "x" |
99 |
| - right_sym = "s" |
100 |
| - else: |
101 |
| - left_sym = "x1" |
102 |
| - right_sym = "x2" |
103 | 80 |
|
| 81 | +def make_binary_params( |
| 82 | + elwise_func_name: str, dtypes: Sequence[DataType] |
| 83 | +) -> List[BinaryParam]: |
104 | 84 | dtypes_strat = st.sampled_from(dtypes)
|
105 |
| - shared_dtypes = st.shared(dtypes_strat) |
106 |
| - if right_is_scalar: |
107 |
| - left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes()) |
108 |
| - right_strat = shared_dtypes.flatmap(lambda d: xps.from_dtype(d, **finite_kw)) |
109 |
| - else: |
110 |
| - if func_type is FuncType.IOP: |
111 |
| - shared_shapes = st.shared(hh.shapes()) |
112 |
| - left_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) |
113 |
| - right_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) |
114 |
| - else: |
115 |
| - left_strat, right_strat = hh.two_mutual_arrays(dtypes) |
116 |
| - |
117 |
| - if func_type is FuncType.FUNC: |
118 |
| - func = getattr(xp, func_name) |
119 |
| - else: |
120 |
| - op_sym = all_op_to_symbol[func_name] |
121 |
| - expr = f"{left_sym} {op_sym} {right_sym}" |
122 |
| - |
123 |
| - if func_type is FuncType.OP: |
124 |
| - |
125 |
| - def func(l: Array, r: Union[Scalar, Array]) -> Array: |
126 |
| - locals_ = {} |
127 |
| - locals_[left_sym] = l |
128 |
| - locals_[right_sym] = r |
129 |
| - return eval(expr, locals_) |
130 | 85 |
|
| 86 | + def make_param( |
| 87 | + func_name: str, func_type: FuncType, right_is_scalar: bool |
| 88 | + ) -> BinaryParam: |
| 89 | + if right_is_scalar: |
| 90 | + left_sym = "x" |
| 91 | + right_sym = "s" |
131 | 92 | else:
|
| 93 | + left_sym = "x1" |
| 94 | + right_sym = "x2" |
| 95 | + |
| 96 | + shared_dtypes = st.shared(dtypes_strat) |
| 97 | + if right_is_scalar: |
| 98 | + left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes()) |
| 99 | + right_strat = shared_dtypes.flatmap( |
| 100 | + lambda d: xps.from_dtype(d, **finite_kw) |
| 101 | + ) |
| 102 | + else: |
| 103 | + if func_type is FuncType.IOP: |
| 104 | + shared_shapes = st.shared(hh.shapes()) |
| 105 | + left_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) |
| 106 | + right_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) |
| 107 | + else: |
| 108 | + left_strat, right_strat = hh.two_mutual_arrays(dtypes) |
| 109 | + |
| 110 | + if func_type is FuncType.FUNC: |
| 111 | + func = getattr(xp, func_name) |
| 112 | + else: |
| 113 | + op_sym = all_op_to_symbol[func_name] |
| 114 | + expr = f"{left_sym} {op_sym} {right_sym}" |
| 115 | + if func_type is FuncType.OP: |
132 | 116 |
|
133 |
| - def func(l: Array, r: Union[Scalar, Array]) -> Array: |
134 |
| - locals_ = {} |
135 |
| - locals_[left_sym] = ah.asarray(l, copy=True) # prevents left mutating |
136 |
| - locals_[right_sym] = r |
137 |
| - exec(expr, locals_) |
138 |
| - return locals_[left_sym] |
| 117 | + def func(l: Array, r: Union[Scalar, Array]) -> Array: |
| 118 | + locals_ = {} |
| 119 | + locals_[left_sym] = l |
| 120 | + locals_[right_sym] = r |
| 121 | + return eval(expr, locals_) |
139 | 122 |
|
140 |
| - func.__name__ = func_name # for repr |
| 123 | + else: |
141 | 124 |
|
142 |
| - if func_type is FuncType.IOP: |
143 |
| - res_name = left_sym |
144 |
| - else: |
145 |
| - res_name = "out" |
| 125 | + def func(l: Array, r: Union[Scalar, Array]) -> Array: |
| 126 | + locals_ = {} |
| 127 | + locals_[left_sym] = ah.asarray( |
| 128 | + l, copy=True |
| 129 | + ) # prevents left mutating |
| 130 | + locals_[right_sym] = r |
| 131 | + exec(expr, locals_) |
| 132 | + return locals_[left_sym] |
146 | 133 |
|
147 |
| - f_id = func_name |
148 |
| - if right_is_scalar: |
149 |
| - f_id += "(x, s)" |
150 |
| - else: |
151 |
| - f_id += "(x1, x2)" |
152 |
| - |
153 |
| - return pytest.param( |
154 |
| - func_name, |
155 |
| - func, |
156 |
| - left_sym, |
157 |
| - left_strat, |
158 |
| - right_sym, |
159 |
| - right_strat, |
160 |
| - right_is_scalar, |
161 |
| - res_name, |
162 |
| - id=f_id, |
163 |
| - ) |
| 134 | + func.__name__ = func_name # for repr |
164 | 135 |
|
| 136 | + if func_type is FuncType.IOP: |
| 137 | + res_name = left_sym |
| 138 | + else: |
| 139 | + res_name = "out" |
| 140 | + |
| 141 | + return pytest.param( |
| 142 | + func_name, |
| 143 | + func, |
| 144 | + left_sym, |
| 145 | + left_strat, |
| 146 | + right_sym, |
| 147 | + right_strat, |
| 148 | + right_is_scalar, |
| 149 | + res_name, |
| 150 | + id=f"{func_name}({left_sym}, {right_sym})", |
| 151 | + ) |
165 | 152 |
|
166 |
| -def make_binary_params(func_name: str, dtypes: Sequence[DataType]) -> List[BinaryParam]: |
167 |
| - op = func_to_op[func_name] |
| 153 | + op_name = func_to_op[elwise_func_name] |
168 | 154 | params = [
|
169 |
| - _make_binary_param(func_name, False, dtypes), |
170 |
| - _make_binary_param(op, False, dtypes), |
171 |
| - _make_binary_param(op, True, dtypes), |
| 155 | + make_param(elwise_func_name, FuncType.FUNC, False), |
| 156 | + make_param(op_name, FuncType.OP, False), |
| 157 | + make_param(op_name, FuncType.OP, True), |
172 | 158 | ]
|
173 |
| - iop = f"__i{op[2:]}" |
174 |
| - if iop in dh.inplace_op_to_symbol.keys(): |
175 |
| - params.append(_make_binary_param(iop, False, dtypes)) |
176 |
| - params.append(_make_binary_param(iop, True, dtypes)) |
| 159 | + iop_name = f"__i{op_name[2:]}" |
| 160 | + if iop_name in dh.inplace_op_to_symbol.keys(): |
| 161 | + params.append(make_param(iop_name, FuncType.IOP, False)) |
| 162 | + params.append(make_param(iop_name, FuncType.IOP, True)) |
| 163 | + |
177 | 164 | return params
|
178 | 165 |
|
179 | 166 |
|
|
0 commit comments