Skip to content

Commit fd21315

Browse files
committed
Test can_cast()
1 parent 756ae0c commit fd21315

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

array_api_tests/test_data_type_functions.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,32 @@ def test_broadcast_to(x, data):
9999
# TODO: test values
100100

101101

102+
@given(_from=xps.scalar_dtypes(), to=xps.scalar_dtypes(), data=st.data())
103+
def test_can_cast(_from, to, data):
104+
from_ = data.draw(
105+
st.just(_from) | xps.arrays(dtype=_from, shape=hh.shapes()), label="from_"
106+
)
107+
108+
out = xp.can_cast(from_, to)
109+
110+
f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]"
111+
assert isinstance(out, bool), f"{type(out)=}, but should be bool {f_func}"
112+
if _from == xp.bool:
113+
expected = to == xp.bool
114+
else:
115+
for dtypes in [dh.all_int_dtypes, dh.float_dtypes]:
116+
if _from in dtypes:
117+
same_family = to in dtypes
118+
break
119+
if same_family:
120+
from_min, from_max = dh.dtype_ranges[_from]
121+
to_min, to_max = dh.dtype_ranges[to]
122+
expected = from_min >= to_min and from_max <= to_max
123+
else:
124+
expected = False
125+
assert out == expected, f"{out=}, but should be {expected} {f_func}"
126+
127+
102128
def make_dtype_id(dtype: DataType) -> str:
103129
return dh.dtype_to_name[dtype]
104130

0 commit comments

Comments
 (0)