|
| 1 | +import re |
1 | 2 | from functools import lru_cache
|
2 |
| -from typing import NamedTuple, Tuple, Union |
| 3 | +from inspect import signature |
| 4 | +from typing import Dict, NamedTuple, Tuple, Union |
3 | 5 | from warnings import warn
|
4 | 6 |
|
5 | 7 | from . import _array_module as xp
|
6 | 8 | from ._array_module import _UndefinedStub
|
| 9 | +from .stubs import name_to_func |
7 | 10 | from .typing import DataType, ScalarType
|
8 | 11 |
|
9 | 12 | __all__ = [
|
@@ -99,8 +102,8 @@ class MinMax(NamedTuple):
|
99 | 102 | xp.uint16: MinMax(0, +65_535),
|
100 | 103 | xp.uint32: MinMax(0, +4_294_967_295),
|
101 | 104 | xp.uint64: MinMax(0, +18_446_744_073_709_551_615),
|
102 |
| - xp.float32: MinMax(-3.4028234663852886e+38, 3.4028234663852886e+38), |
103 |
| - xp.float64: MinMax(-1.7976931348623157e+308, 1.7976931348623157e+308), |
| 105 | + xp.float32: MinMax(-3.4028234663852886e38, 3.4028234663852886e38), |
| 106 | + xp.float64: MinMax(-1.7976931348623157e308, 1.7976931348623157e308), |
104 | 107 | }
|
105 | 108 |
|
106 | 109 | dtype_nbits = {
|
@@ -196,67 +199,28 @@ def result_type(*dtypes: DataType):
|
196 | 199 | return result
|
197 | 200 |
|
198 | 201 |
|
199 |
| -func_in_dtypes = { |
200 |
| - # elementwise |
201 |
| - "abs": numeric_dtypes, |
202 |
| - "acos": float_dtypes, |
203 |
| - "acosh": float_dtypes, |
204 |
| - "add": numeric_dtypes, |
205 |
| - "asin": float_dtypes, |
206 |
| - "asinh": float_dtypes, |
207 |
| - "atan": float_dtypes, |
208 |
| - "atan2": float_dtypes, |
209 |
| - "atanh": float_dtypes, |
210 |
| - "bitwise_and": bool_and_all_int_dtypes, |
211 |
| - "bitwise_invert": bool_and_all_int_dtypes, |
212 |
| - "bitwise_left_shift": all_int_dtypes, |
213 |
| - "bitwise_or": bool_and_all_int_dtypes, |
214 |
| - "bitwise_right_shift": all_int_dtypes, |
215 |
| - "bitwise_xor": bool_and_all_int_dtypes, |
216 |
| - "ceil": numeric_dtypes, |
217 |
| - "cos": float_dtypes, |
218 |
| - "cosh": float_dtypes, |
219 |
| - "divide": float_dtypes, |
220 |
| - "equal": all_dtypes, |
221 |
| - "exp": float_dtypes, |
222 |
| - "expm1": float_dtypes, |
223 |
| - "floor": numeric_dtypes, |
224 |
| - "floor_divide": numeric_dtypes, |
225 |
| - "greater": numeric_dtypes, |
226 |
| - "greater_equal": numeric_dtypes, |
227 |
| - "isfinite": numeric_dtypes, |
228 |
| - "isinf": numeric_dtypes, |
229 |
| - "isnan": numeric_dtypes, |
230 |
| - "less": numeric_dtypes, |
231 |
| - "less_equal": numeric_dtypes, |
232 |
| - "log": float_dtypes, |
233 |
| - "logaddexp": float_dtypes, |
234 |
| - "log10": float_dtypes, |
235 |
| - "log1p": float_dtypes, |
236 |
| - "log2": float_dtypes, |
237 |
| - "logical_and": (xp.bool,), |
238 |
| - "logical_not": (xp.bool,), |
239 |
| - "logical_or": (xp.bool,), |
240 |
| - "logical_xor": (xp.bool,), |
241 |
| - "multiply": numeric_dtypes, |
242 |
| - "negative": numeric_dtypes, |
243 |
| - "not_equal": all_dtypes, |
244 |
| - "positive": numeric_dtypes, |
245 |
| - "pow": numeric_dtypes, |
246 |
| - "remainder": numeric_dtypes, |
247 |
| - "round": numeric_dtypes, |
248 |
| - "sign": numeric_dtypes, |
249 |
| - "sin": float_dtypes, |
250 |
| - "sinh": float_dtypes, |
251 |
| - "sqrt": float_dtypes, |
252 |
| - "square": numeric_dtypes, |
253 |
| - "subtract": numeric_dtypes, |
254 |
| - "tan": float_dtypes, |
255 |
| - "tanh": float_dtypes, |
256 |
| - "trunc": numeric_dtypes, |
257 |
| - # searching |
258 |
| - "where": all_dtypes, |
| 202 | +r_in_dtypes = re.compile("x1?: array\n.+Should have an? (.+) data type.") |
| 203 | +r_int_note = re.compile( |
| 204 | + "If one or both of the input arrays have integer data types, " |
| 205 | + "the result is implementation-dependent" |
| 206 | +) |
| 207 | +category_to_dtypes = { |
| 208 | + "boolean": (xp.bool,), |
| 209 | + "integer": all_int_dtypes, |
| 210 | + "floating-point": float_dtypes, |
| 211 | + "numeric": numeric_dtypes, |
| 212 | + "integer or boolean": bool_and_all_int_dtypes, |
259 | 213 | }
|
| 214 | +func_in_dtypes: Dict[str, Tuple[DataType, ...]] = {} |
| 215 | +for name, func in name_to_func.items(): |
| 216 | + if m := r_in_dtypes.search(func.__doc__): |
| 217 | + dtype_category = m.group(1) |
| 218 | + if dtype_category == "numeric" and r_int_note.search(func.__doc__): |
| 219 | + dtype_category = "floating-point" |
| 220 | + dtypes = category_to_dtypes[dtype_category] |
| 221 | + func_in_dtypes[name] = dtypes |
| 222 | + elif any("x" in name for name in signature(func).parameters.keys()): |
| 223 | + func_in_dtypes[name] = all_dtypes |
260 | 224 |
|
261 | 225 |
|
262 | 226 | func_returns_bool = {
|
|
0 commit comments