Skip to content

Commit 2ea31c2

Browse files
committed
Use _test_stacks in test_matmul()
1 parent c325be0 commit 2ea31c2

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

array_api_tests/test_linalg.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,11 +275,14 @@ def test_matmul(x1, x2):
275275
assert res.shape == ()
276276
elif len(x1.shape) == 1:
277277
assert res.shape == x2.shape[:-2] + x2.shape[-1:]
278+
_test_stacks(_array_module.linalg.matmul, x1, x2, res=res, dims=1)
278279
elif len(x2.shape) == 1:
279280
assert res.shape == x1.shape[:-1]
281+
_test_stacks(_array_module.linalg.matmul, x1, x2, res=res, dims=1)
280282
else:
281283
stack_shape = broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
282284
assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1])
285+
_test_stacks(_array_module.linalg.matmul, x1, x2, res=res)
283286

284287
@given(
285288
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),

0 commit comments

Comments
 (0)