Skip to content

Commit daafc40

Browse files
committed
Implement test_matrix_transpose
1 parent d390b9e commit daafc40

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

array_api_tests/test_linalg.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,11 +322,21 @@ def test_matrix_rank(x, kw):
322322
pass
323323

324324
@given(
325-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
325+
x=xps.arrays(dtype=dtypes, shape=matrix_shapes),
326326
)
327327
def test_matrix_transpose(x):
328-
# res = linalg.matrix_transpose(x)
329-
pass
328+
res = linalg.matrix_transpose(x)
329+
true_val = lambda a: _array_module.asarray([[a[i, j] for i in
330+
range(a.shape[0])] for j in
331+
range(a.shape[1])],
332+
dtype=a.dtype)
333+
shape = list(x.shape)
334+
shape[-1], shape[-2] = shape[-2], shape[-1]
335+
shape = tuple(shape)
336+
assert res.shape == shape, "matrix_transpose() did not return the correct shape"
337+
assert res.dtype == x.dtype, "matrix_transpose() did not return the correct dtype"
338+
339+
_test_stacks(linalg.matrix_transpose, x, res=res, true_val=true_val)
330340

331341
@given(
332342
x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),

0 commit comments

Comments
 (0)