Skip to content

Commit f7f1004

Browse files
committed
Rudimentary test_broadcast_arrays
1 parent e1be518 commit f7f1004

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

array_api_tests/test_data_type_functions.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from . import hypothesis_helpers as hh
1111
from . import pytest_helpers as ph
1212
from . import xps
13+
from .algos import broadcast_shapes
1314
from .typing import DataType
1415

1516

@@ -56,6 +57,32 @@ def test_astype(x_dtype, dtype, kw, data):
5657
# TODO: test copy
5758

5859

60+
@given(
61+
shapes=st.integers(1, 5).flatmap(hh.mutually_broadcastable_shapes), data=st.data()
62+
)
63+
def test_broadcast_arrays(shapes, data):
64+
arrays = []
65+
for c, shape in enumerate(shapes, 1):
66+
x = data.draw(xps.arrays(dtype=xps.scalar_dtypes(), shape=shape), label=f"x{c}")
67+
arrays.append(x)
68+
69+
out = xp.broadcast_arrays(*arrays)
70+
71+
out_shape = broadcast_shapes(*shapes)
72+
for i, x in enumerate(arrays):
73+
ph.assert_dtype(
74+
"broadcast_arrays", x.dtype, out[i].dtype, repr_name=f"out[{i}].dtype"
75+
)
76+
ph.assert_result_shape(
77+
"broadcast_arrays",
78+
shapes,
79+
out[i].shape,
80+
out_shape,
81+
repr_name=f"out[{i}].shape",
82+
)
83+
# TODO: test values
84+
85+
5986
def make_dtype_id(dtype: DataType) -> str:
6087
return dh.dtype_to_name[dtype]
6188

0 commit comments

Comments
 (0)