|
17 | 17 | from hypothesis.strategies import booleans, composite, none, tuples, integers, shared
|
18 | 18 |
|
19 | 19 | from .array_helpers import (assert_exactly_equal, ndindex, asarray,
|
20 |
| - numeric_dtype_objects, promote_dtypes) |
| 20 | + numeric_dtype_objects, promote_dtypes, equal, |
| 21 | + zero, infinity) |
21 | 22 | from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes,
|
22 | 23 | square_matrix_shapes, symmetric_matrices,
|
23 | 24 | positive_definite_matrices, MAX_ARRAY_SIZE,
|
@@ -378,11 +379,38 @@ def test_qr(x, kw):
|
378 | 379 | pass
|
379 | 380 |
|
380 | 381 | @given(
|
381 |
| - x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes), |
| 382 | + x=xps.arrays(dtype=xps.floating_dtypes(), shape=square_matrix_shapes), |
382 | 383 | )
|
383 | 384 | def test_slogdet(x):
|
384 |
| - # res = linalg.slogdet(x) |
385 |
| - pass |
| 385 | + res = linalg.slogdet(x) |
| 386 | + |
| 387 | + _test_namedtuple(res, ['sign', 'logabsdet'], 'slotdet') |
| 388 | + |
| 389 | + sign, logabsdet = res |
| 390 | + |
| 391 | + assert sign.dtype == x.dtype, "slogdet().sign did not return the correct dtype" |
| 392 | + assert sign.shape == x.shape[:-2], "slogdet().sign did not return the correct shape" |
| 393 | + assert logabsdet.dtype == x.dtype, "slogdet().logabsdet did not return the correct dtype" |
| 394 | + assert logabsdet.shape == x.shape[:-2], "slogdet().logabsdet did not return the correct shape" |
| 395 | + |
| 396 | + |
| 397 | + _test_stacks(lambda x: linalg.slogdet(x).sign, x, |
| 398 | + res=sign, dims=0) |
| 399 | + _test_stacks(lambda x: linalg.slogdet(x).logabsdet, x, |
| 400 | + res=logabsdet, dims=0) |
| 401 | + |
| 402 | + # Check that when the determinant is 0, the sign and logabsdet are (0, |
| 403 | + # -inf). |
| 404 | + d = linalg.det(x) |
| 405 | + zero_det = equal(d, zero(d.shape, d.dtype)) |
| 406 | + assert_exactly_equal(sign[zero_det], zero(sign[zero_det].shape, x.dtype)) |
| 407 | + assert_exactly_equal(logabsdet[zero_det], -infinity(logabsdet[zero_det].shape, x.dtype)) |
| 408 | + |
| 409 | + # More generally, det(x) should equal sign*exp(logabsdet), but this does |
| 410 | + # not hold exactly due to floating-point loss of precision. |
| 411 | + |
| 412 | + # TODO: Test this when we have tests for floating-point values. |
| 413 | + # assert all(abs(linalg.det(x) - sign*exp(logabsdet)) < eps) |
386 | 414 |
|
387 | 415 | @given(
|
388 | 416 | x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
|
|
0 commit comments