Skip to content

Commit 544d8ea

Browse files
committed
Implement test_matrix_power()
1 parent 873eeff commit 544d8ea

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

array_api_tests/test_linalg.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,13 +292,23 @@ def test_matrix_norm(x, kw):
292292
# res = _array_module.linalg.matrix_norm(x, **kw)
293293
pass
294294

295+
matrix_power_n = shared(integers(-1000, 1000), key='matrix_power n')
295296
@given(
296-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
297-
n=integers(),
297+
# Generate any square matrix if n >= 0 but only invertible matrices if n < 0
298+
x=matrix_power_n.flatmap(lambda n: invertible_matrices() if n < 0 else
299+
xps.arrays(dtype=xps.floating_dtypes(),
300+
shape=square_matrix_shapes)),
301+
n=matrix_power_n,
298302
)
299303
def test_matrix_power(x, n):
300-
# res = _array_module.linalg.matrix_power(x, n)
301-
pass
304+
res = _array_module.linalg.matrix_power(x, n)
305+
if n == 0:
306+
true_val = lambda x: _array_module.eye(x.shape[0], dtype=x.dtype)
307+
else:
308+
true_val = None
309+
# _test_stacks only works with array arguments
310+
func = lambda x: _array_module.linalg.matrix_power(x, n)
311+
_test_stacks(func, x, res=res, true_val=true_val)
302312

303313
@given(
304314
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),

0 commit comments

Comments
 (0)