Skip to content

Commit 756ae0c

Browse files
committed
Rudimentary test_broadcast_to
1 parent f7f1004 commit 756ae0c

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

array_api_tests/test_data_type_functions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,22 @@ def test_broadcast_arrays(shapes, data):
8383
# TODO: test values
8484

8585

86+
@given(x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), data=st.data())
87+
def test_broadcast_to(x, data):
88+
shape = data.draw(
89+
hh.mutually_broadcastable_shapes(1, base_shape=x.shape)
90+
.map(lambda S: S[0])
91+
.filter(lambda s: broadcast_shapes(x.shape, s) == s),
92+
label="shape",
93+
)
94+
95+
out = xp.broadcast_to(x, shape)
96+
97+
ph.assert_dtype("broadcast_to", x.dtype, out.dtype)
98+
ph.assert_shape("broadcast_to", out.shape, shape)
99+
# TODO: test values
100+
101+
86102
def make_dtype_id(dtype: DataType) -> str:
87103
return dh.dtype_to_name[dtype]
88104

0 commit comments

Comments
 (0)