10
10
"""
11
11
12
12
import math
13
+ from enum import Enum , auto
13
14
from typing import Callable , List , Sequence , Union
14
15
15
16
import pytest
25
26
from . import xps
26
27
from .typing import Array , DataType , Param , Scalar
27
28
28
-
29
29
# We might as well use this implementation rather than xp.broadcast_shapes()
30
30
from .test_broadcasting import broadcast_shapes
31
31
@@ -74,38 +74,53 @@ def op_func(x: Array) -> Array:
74
74
]
75
75
76
76
77
+ class FuncType (Enum ):
78
+ FUNC = auto ()
79
+ OP = auto ()
80
+ IOP = auto ()
81
+
82
+ @classmethod
83
+ def from_name (cls , name : str ):
84
+ if name in dh .binary_op_to_symbol .keys ():
85
+ return cls .OP
86
+ elif name in dh .inplace_op_to_symbol .keys ():
87
+ return cls .IOP
88
+ else :
89
+ return cls .FUNC
90
+
91
+
77
92
def _make_binary_param (
78
93
func_name : str , right_is_scalar : bool , dtypes : Sequence [DataType ]
79
94
) -> BinaryParam :
80
- if func_name in dh .binary_op_to_symbol .keys ():
81
- func_type = "op"
82
- elif func_name in dh .inplace_op_to_symbol .keys ():
83
- func_type = "iop"
84
- else :
85
- func_type = "func"
95
+ func_type = FuncType .from_name (func_name )
86
96
87
- left_sym , right_sym = ("x" , "s" ) if right_is_scalar else ("x1" , "x2" )
97
+ if right_is_scalar :
98
+ left_sym = "x"
99
+ right_sym = "s"
100
+ else :
101
+ left_sym = "x1"
102
+ right_sym = "x2"
88
103
89
104
dtypes_strat = st .sampled_from (dtypes )
90
105
shared_dtypes = st .shared (dtypes_strat )
91
106
if right_is_scalar :
92
107
left_strat = xps .arrays (dtype = shared_dtypes , shape = hh .shapes ())
93
108
right_strat = shared_dtypes .flatmap (lambda d : xps .from_dtype (d , ** finite_kw ))
94
109
else :
95
- if func_type == "iop" :
110
+ if func_type is FuncType . IOP :
96
111
shared_shapes = st .shared (hh .shapes ())
97
112
left_strat = xps .arrays (dtype = shared_dtypes , shape = shared_shapes )
98
113
right_strat = xps .arrays (dtype = shared_dtypes , shape = shared_shapes )
99
114
else :
100
115
left_strat , right_strat = hh .two_mutual_arrays (dtypes )
101
116
102
- if func_type == "func" :
117
+ if func_type is FuncType . FUNC :
103
118
func = getattr (xp , func_name )
104
119
else :
105
120
op_sym = all_op_to_symbol [func_name ]
106
121
expr = f"{ left_sym } { op_sym } { right_sym } "
107
122
108
- if func_type == "op" :
123
+ if func_type is FuncType . OP :
109
124
110
125
def func (l : Array , r : Union [Scalar , Array ]) -> Array :
111
126
locals_ = {}
@@ -124,7 +139,7 @@ def func(l: Array, r: Union[Scalar, Array]) -> Array:
124
139
125
140
func .__name__ = func_name # for repr
126
141
127
- if func_type == "iop" :
142
+ if func_type is FuncType . IOP :
128
143
res_name = left_sym
129
144
else :
130
145
res_name = "out"
0 commit comments