Skip to content

Commit 10a928b

Browse files
committed
Implement test_slogdet
The det == 0 -> logabsdet == -inf check currently fails with the example x = full((5, 5), 6143., dtype=float32) with NumPy. It's not clear if this sort of thing is expected (and the spec should not be taken so literally), or if this is a bug in NumPy.
1 parent 5b44e47 commit 10a928b

File tree

1 file changed

+32
-4
lines changed

1 file changed

+32
-4
lines changed

array_api_tests/test_linalg.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from hypothesis.strategies import booleans, composite, none, tuples, integers, shared
1818

1919
from .array_helpers import (assert_exactly_equal, ndindex, asarray,
20-
numeric_dtype_objects, promote_dtypes)
20+
numeric_dtype_objects, promote_dtypes, equal,
21+
zero, infinity)
2122
from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes,
2223
square_matrix_shapes, symmetric_matrices,
2324
positive_definite_matrices, MAX_ARRAY_SIZE,
@@ -378,11 +379,38 @@ def test_qr(x, kw):
378379
pass
379380

380381
@given(
381-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
382+
x=xps.arrays(dtype=xps.floating_dtypes(), shape=square_matrix_shapes),
382383
)
383384
def test_slogdet(x):
384-
# res = linalg.slogdet(x)
385-
pass
385+
res = linalg.slogdet(x)
386+
387+
_test_namedtuple(res, ['sign', 'logabsdet'], 'slotdet')
388+
389+
sign, logabsdet = res
390+
391+
assert sign.dtype == x.dtype, "slogdet().sign did not return the correct dtype"
392+
assert sign.shape == x.shape[:-2], "slogdet().sign did not return the correct shape"
393+
assert logabsdet.dtype == x.dtype, "slogdet().logabsdet did not return the correct dtype"
394+
assert logabsdet.shape == x.shape[:-2], "slogdet().logabsdet did not return the correct shape"
395+
396+
397+
_test_stacks(lambda x: linalg.slogdet(x).sign, x,
398+
res=sign, dims=0)
399+
_test_stacks(lambda x: linalg.slogdet(x).logabsdet, x,
400+
res=logabsdet, dims=0)
401+
402+
# Check that when the determinant is 0, the sign and logabsdet are (0,
403+
# -inf).
404+
d = linalg.det(x)
405+
zero_det = equal(d, zero(d.shape, d.dtype))
406+
assert_exactly_equal(sign[zero_det], zero(sign[zero_det].shape, x.dtype))
407+
assert_exactly_equal(logabsdet[zero_det], -infinity(logabsdet[zero_det].shape, x.dtype))
408+
409+
# More generally, det(x) should equal sign*exp(logabsdet), but this does
410+
# not hold exactly due to floating-point loss of precision.
411+
412+
# TODO: Test this when we have tests for floating-point values.
413+
# assert all(abs(linalg.det(x) - sign*exp(logabsdet)) < eps)
386414

387415
@given(
388416
x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),

0 commit comments

Comments
 (0)