Skip to content

Commit e1be518

Browse files
committed
Rudimentary test_astype
1 parent cd941c9 commit e1be518

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def get_scalar_type(dtype: DataType) -> ScalarType:
8686

8787

8888
class MinMax(NamedTuple):
89-
min: int
90-
max: int
89+
min: Union[int, float]
90+
max: Union[int, float]
9191

9292

9393
dtype_ranges = {
@@ -99,6 +99,8 @@ class MinMax(NamedTuple):
9999
xp.uint16: MinMax(0, +65_535),
100100
xp.uint32: MinMax(0, +4_294_967_295),
101101
xp.uint64: MinMax(0, +18_446_744_073_709_551_615),
102+
xp.float32: MinMax(-3.4028234663852886e+38, 3.4028234663852886e+38),
103+
xp.float64: MinMax(-1.7976931348623157e+308, 1.7976931348623157e+308),
102104
}
103105

104106
dtype_nbits = {

array_api_tests/test_data_type_functions.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,61 @@
1+
import struct
2+
from typing import Union
3+
14
import pytest
25
from hypothesis import given
6+
from hypothesis import strategies as st
37

48
from . import _array_module as xp
59
from . import dtype_helpers as dh
610
from . import hypothesis_helpers as hh
711
from . import pytest_helpers as ph
12+
from . import xps
813
from .typing import DataType
914

1015

16+
def float32(n: Union[int, float]) -> float:
17+
return struct.unpack("!f", struct.pack("!f", float(n)))[0]
18+
19+
20+
@given(
21+
x_dtype=xps.scalar_dtypes(),
22+
dtype=xps.scalar_dtypes(),
23+
kw=hh.kwargs(copy=st.booleans()),
24+
data=st.data(),
25+
)
26+
def test_astype(x_dtype, dtype, kw, data):
27+
if xp.bool in (x_dtype, dtype):
28+
elements_strat = xps.from_dtype(x_dtype)
29+
else:
30+
m1, M1 = dh.dtype_ranges[x_dtype]
31+
m2, M2 = dh.dtype_ranges[dtype]
32+
if dh.is_int_dtype(x_dtype):
33+
cast = int
34+
elif x_dtype == xp.float32:
35+
cast = float32
36+
else:
37+
cast = float
38+
min_value = cast(max(m1, m2))
39+
max_value = cast(min(M1, M2))
40+
elements_strat = xps.from_dtype(
41+
x_dtype,
42+
min_value=min_value,
43+
max_value=max_value,
44+
allow_nan=False,
45+
allow_infinity=False,
46+
)
47+
x = data.draw(
48+
xps.arrays(dtype=x_dtype, shape=hh.shapes(), elements=elements_strat), label="x"
49+
)
50+
51+
out = xp.astype(x, dtype, **kw)
52+
53+
ph.assert_kw_dtype("astype", dtype, out.dtype)
54+
ph.assert_shape("astype", out.shape, x.shape)
55+
# TODO: test values
56+
# TODO: test copy
57+
58+
1159
def make_dtype_id(dtype: DataType) -> str:
1260
return dh.dtype_to_name[dtype]
1361

0 commit comments

Comments
 (0)