Skip to content

Commit afc3e8c

Browse files
committed
Hard-code array scalar casting input dtypes for dh.func_in_dtypes
1 parent 992900b commit afc3e8c

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,13 @@ def result_type(*dtypes: DataType):
360360
func_returns_bool[iop] = func_returns_bool[op]
361361

362362

363+
func_in_dtypes["__bool__"] = (xp.bool,)
364+
func_in_dtypes["__int__"] = all_int_dtypes
365+
func_in_dtypes["__index__"] = all_int_dtypes
366+
func_in_dtypes["__float__"] = float_dtypes
367+
func_in_dtypes["__dlpack__"] = numeric_dtypes
368+
369+
363370
@lru_cache
364371
def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str:
365372
f_types = []

0 commit comments

Comments
 (0)