Skip to content

Commit add41f0

Browse files
committed
Create dh.func_in_dtypes from parsing the spec
1 parent 7cb3dc0 commit add41f0

File tree

1 file changed

+27
-63
lines changed

1 file changed

+27
-63
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 27 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import re
12
from functools import lru_cache
2-
from typing import NamedTuple, Tuple, Union
3+
from inspect import signature
4+
from typing import Dict, NamedTuple, Tuple, Union
35
from warnings import warn
46

57
from . import _array_module as xp
68
from ._array_module import _UndefinedStub
9+
from .stubs import name_to_func
710
from .typing import DataType, ScalarType
811

912
__all__ = [
@@ -99,8 +102,8 @@ class MinMax(NamedTuple):
99102
xp.uint16: MinMax(0, +65_535),
100103
xp.uint32: MinMax(0, +4_294_967_295),
101104
xp.uint64: MinMax(0, +18_446_744_073_709_551_615),
102-
xp.float32: MinMax(-3.4028234663852886e+38, 3.4028234663852886e+38),
103-
xp.float64: MinMax(-1.7976931348623157e+308, 1.7976931348623157e+308),
105+
xp.float32: MinMax(-3.4028234663852886e38, 3.4028234663852886e38),
106+
xp.float64: MinMax(-1.7976931348623157e308, 1.7976931348623157e308),
104107
}
105108

106109
dtype_nbits = {
@@ -196,67 +199,28 @@ def result_type(*dtypes: DataType):
196199
return result
197200

198201

199-
func_in_dtypes = {
200-
# elementwise
201-
"abs": numeric_dtypes,
202-
"acos": float_dtypes,
203-
"acosh": float_dtypes,
204-
"add": numeric_dtypes,
205-
"asin": float_dtypes,
206-
"asinh": float_dtypes,
207-
"atan": float_dtypes,
208-
"atan2": float_dtypes,
209-
"atanh": float_dtypes,
210-
"bitwise_and": bool_and_all_int_dtypes,
211-
"bitwise_invert": bool_and_all_int_dtypes,
212-
"bitwise_left_shift": all_int_dtypes,
213-
"bitwise_or": bool_and_all_int_dtypes,
214-
"bitwise_right_shift": all_int_dtypes,
215-
"bitwise_xor": bool_and_all_int_dtypes,
216-
"ceil": numeric_dtypes,
217-
"cos": float_dtypes,
218-
"cosh": float_dtypes,
219-
"divide": float_dtypes,
220-
"equal": all_dtypes,
221-
"exp": float_dtypes,
222-
"expm1": float_dtypes,
223-
"floor": numeric_dtypes,
224-
"floor_divide": numeric_dtypes,
225-
"greater": numeric_dtypes,
226-
"greater_equal": numeric_dtypes,
227-
"isfinite": numeric_dtypes,
228-
"isinf": numeric_dtypes,
229-
"isnan": numeric_dtypes,
230-
"less": numeric_dtypes,
231-
"less_equal": numeric_dtypes,
232-
"log": float_dtypes,
233-
"logaddexp": float_dtypes,
234-
"log10": float_dtypes,
235-
"log1p": float_dtypes,
236-
"log2": float_dtypes,
237-
"logical_and": (xp.bool,),
238-
"logical_not": (xp.bool,),
239-
"logical_or": (xp.bool,),
240-
"logical_xor": (xp.bool,),
241-
"multiply": numeric_dtypes,
242-
"negative": numeric_dtypes,
243-
"not_equal": all_dtypes,
244-
"positive": numeric_dtypes,
245-
"pow": numeric_dtypes,
246-
"remainder": numeric_dtypes,
247-
"round": numeric_dtypes,
248-
"sign": numeric_dtypes,
249-
"sin": float_dtypes,
250-
"sinh": float_dtypes,
251-
"sqrt": float_dtypes,
252-
"square": numeric_dtypes,
253-
"subtract": numeric_dtypes,
254-
"tan": float_dtypes,
255-
"tanh": float_dtypes,
256-
"trunc": numeric_dtypes,
257-
# searching
258-
"where": all_dtypes,
202+
r_in_dtypes = re.compile("x1?: array\n.+Should have an? (.+) data type.")
203+
r_int_note = re.compile(
204+
"If one or both of the input arrays have integer data types, "
205+
"the result is implementation-dependent"
206+
)
207+
category_to_dtypes = {
208+
"boolean": (xp.bool,),
209+
"integer": all_int_dtypes,
210+
"floating-point": float_dtypes,
211+
"numeric": numeric_dtypes,
212+
"integer or boolean": bool_and_all_int_dtypes,
259213
}
214+
func_in_dtypes: Dict[str, Tuple[DataType, ...]] = {}
215+
for name, func in name_to_func.items():
216+
if m := r_in_dtypes.search(func.__doc__):
217+
dtype_category = m.group(1)
218+
if dtype_category == "numeric" and r_int_note.search(func.__doc__):
219+
dtype_category = "floating-point"
220+
dtypes = category_to_dtypes[dtype_category]
221+
func_in_dtypes[name] = dtypes
222+
elif any("x" in name for name in signature(func).parameters.keys()):
223+
func_in_dtypes[name] = all_dtypes
260224

261225

262226
func_returns_bool = {

0 commit comments

Comments
 (0)