Skip to content

Commit 859911b

Browse files
committed
Rudimentary test_matmul
`numpy.array_api.matmul` currently doesn't support broadcastable shapes
1 parent 26bc1d4 commit 859911b

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,11 @@ def matrix_shapes(draw, stack_shapes=shapes()):
136136

137137
square_matrix_shapes = matrix_shapes().filter(lambda shape: shape[-1] == shape[-2])
138138

139-
def mutually_broadcastable_shapes(num_shapes: int) -> SearchStrategy[Tuple[Tuple]]:
139+
def mutually_broadcastable_shapes(
140+
num_shapes: int, **kw
141+
) -> SearchStrategy[Tuple[Tuple[int, ...], ...]]:
140142
return (
141-
xps.mutually_broadcastable_shapes(num_shapes)
143+
xps.mutually_broadcastable_shapes(num_shapes, **kw)
142144
.map(lambda BS: BS.input_shapes)
143145
.filter(lambda shapes: all(
144146
prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes

array_api_tests/test_type_promotion.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,18 @@ def test_where(in_dtypes, out_dtype, shapes, data):
205205
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
206206

207207

208+
numeric_promotion_table_params = promotion_table_params[1:]
209+
210+
211+
@pytest.mark.parametrize('in_dtypes, out_dtype', numeric_promotion_table_params)
212+
@given(shapes=hh.mutually_broadcastable_shapes(2, min_dims=1), data=st.data())
213+
def test_matmul(in_dtypes, out_dtype, shapes, data):
214+
x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1')
215+
x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2')
216+
out = xp.matmul(x1, x2)
217+
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
218+
219+
208220
op_params: List[Tuple[str, str, Tuple[DT, ...], DT]] = []
209221
op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol}
210222
for op, symbol in op_to_symbol.items():

0 commit comments

Comments
 (0)