diff --git a/metric_learn/_util.py b/metric_learn/_util.py
index bd57fd5f..7c70e4bf 100644
--- a/metric_learn/_util.py
+++ b/metric_learn/_util.py
@@ -1,5 +1,7 @@
+import warnings
 import numpy as np
 import six
+from numpy.linalg import LinAlgError
 from sklearn.utils import check_array
 from sklearn.utils.validation import check_X_y
 from metric_learn.exceptions import PreprocessorError
@@ -324,31 +326,75 @@ def check_collapsed_pairs(pairs):
         "in total.".format(num_ident, pairs.shape[0]))
 
 
-def transformer_from_metric(metric):
-  """Computes the transformation matrix from the Mahalanobis matrix.
+def _check_sdp_from_eigen(w, tol=None):
+  """Checks that the given eigenvalues are all non-negative, up to a tolerance
+  level whose default value depends on the eigenvalues themselves.
 
-  Since by definition the metric `M` is positive semi-definite (PSD), it
-  admits a Cholesky decomposition: L = cholesky(M).T. However, currently the
-  computation of the Cholesky decomposition used does not support
-  non-definite matrices. If the metric is not definite, this method will
-  return L = V.T w^( -1/2), with M = V*w*V.T being the eigenvector
-  decomposition of M with the eigenvalues in the diagonal matrix w and the
-  columns of V being the eigenvectors. If M is diagonal, this method will
-  just return its elementwise square root (since the diagonalization of
-  the matrix is itself).
+  Parameters
+  ----------
+  w : array-like, shape=(n_eigenvalues,)
+    Eigenvalues to check for positive semidefiniteness.
+
+  tol : positive `float`, optional
+    Negative eigenvalues above -tol are considered zero. If tol is None,
+    and eps is the epsilon value for the datatype of w, then tol is set to
+    w.max() * len(w) * eps.
+
+  See Also
+  --------
+  np.linalg.matrix_rank for more details on the choice of tolerance (the same
+  strategy is applied here)
+  """
+  if tol is None:
+    tol = w.max() * len(w) * np.finfo(w.dtype).eps
+  if tol < 0:
+    raise ValueError("tol should be positive.")
+  if np.any(w < -tol):
+    raise ValueError("Matrix is not positive semidefinite (PSD).")
+
+
+def transformer_from_metric(metric, tol=None):
+  """Returns the transformation matrix from the Mahalanobis matrix.
+
+  Returns the transformation matrix from the Mahalanobis matrix, i.e. the
+  matrix L such that metric=L.T.dot(L).
+
+  Parameters
+  ----------
+  metric : symmetric `np.ndarray`, shape=(d x d)
+    The input metric, from which we want to extract a transformation matrix.
+
+  tol : positive `float`, optional
+    Eigenvalues of `metric` between -tol and 0 are considered zero. If tol is
+    None, and w_max is `metric`'s largest eigenvalue, and eps is the epsilon
+    value for the datatype of `metric`, then tol is set to
+    w_max * metric.shape[0] * eps.
 
   Returns
   -------
-  L : (d x d) matrix
+  L : np.ndarray, shape=(d x d)
+    The transformation matrix, such that L.T.dot(L) == metric.
   """
-
-  if np.allclose(metric, np.diag(np.diag(metric))):
-    return np.sqrt(metric)
-  elif not np.isclose(np.linalg.det(metric), 0):
-    return np.linalg.cholesky(metric).T
+  if not np.allclose(metric, metric.T):
+    raise ValueError("The input metric should be symmetric.")
+  # If M is diagonal, we will just return the elementwise square root:
+  if np.array_equal(metric, np.diag(np.diag(metric))):
+    _check_sdp_from_eigen(np.diag(metric), tol)
+    return np.diag(np.sqrt(np.maximum(0, np.diag(metric))))
   else:
-    w, V = np.linalg.eigh(metric)
-    return V.T * np.sqrt(np.maximum(0, w[:, None]))
+    try:
+      # if `M` is positive definite, it admits a Cholesky decomposition:
+      # L = cholesky(M).T
+      return np.linalg.cholesky(metric).T
+    except LinAlgError:
+      # np.linalg.cholesky raises a LinAlgError on indefinite (and on
+      # rank-deficient) matrices. In that case we return
+      # L = diag(sqrt(w)).dot(V.T), where M = V.dot(np.diag(w)).dot(V.T) is
+      # the eigendecomposition of M, with the eigenvalues in the vector w and
+      # the columns of V being the eigenvectors.
+      w, V = np.linalg.eigh(metric)
+      _check_sdp_from_eigen(w, tol)
+      return V.T * np.sqrt(np.maximum(0, w[:, None]))
 
 
 def validate_vector(u, dtype=None):
diff --git a/test/test_transformer_metric_conversion.py b/test/test_transformer_metric_conversion.py
index 6cfe8281..4328320d 100644
--- a/test/test_transformer_metric_conversion.py
+++ b/test/test_transformer_metric_conversion.py
@@ -1,11 +1,16 @@
 import unittest
 import numpy as np
+import pytest
+from numpy.linalg import LinAlgError
+from scipy.stats import ortho_group
 from sklearn.datasets import load_iris
-from numpy.testing import assert_array_almost_equal
+from numpy.testing import assert_array_almost_equal, assert_allclose
+from sklearn.utils.testing import ignore_warnings
 from metric_learn import (
     LMNN, NCA, LFDA, Covariance, MLKR, LSML_Supervised,
     ITML_Supervised, SDML_Supervised, RCA_Supervised)
+from metric_learn._util import transformer_from_metric
 
 
 class TestTransformerMetricConversion(unittest.TestCase):
@@ -76,6 +81,105 @@ def test_mlkr(self):
     L = mlkr.transformer_
     assert_array_almost_equal(L.T.dot(L), mlkr.get_mahalanobis_matrix())
 
+  @ignore_warnings
+  def test_transformer_from_metric_edge_cases(self):
+    """Test that transformer_from_metric returns the right result in various
+    edge cases"""
+    rng = np.random.RandomState(42)
+
+    # an orthonormal matrix useful for creating matrices with given
+    # eigenvalues:
+    P = ortho_group.rvs(7, random_state=rng)
+
+    # matrix with all its coefficients very low (to check that the algorithm
+    # does not consider it as a diagonal matrix) (non-regression test for
+    # https://github.com/metric-learn/metric-learn/issues/175)
+    M = np.diag([1e-15, 2e-16, 3e-15, 4e-16, 5e-15, 6e-16, 7e-15])
+    M = P.dot(M).dot(P.T)
+    L = transformer_from_metric(M)
+    assert_allclose(L.T.dot(L), M)
+
+    # diagonal matrix
+    M = np.diag(np.abs(rng.randn(5)))
+    L = transformer_from_metric(M)
+    assert_allclose(L.T.dot(L), M)
+
+    # low-rank matrix (with zeros)
+    M = np.zeros((7, 7))
+    small_random = rng.randn(3, 3)
+    M[:3, :3] = small_random.T.dot(small_random)
+    L = transformer_from_metric(M)
+    assert_allclose(L.T.dot(L), M)
+
+    # low-rank matrix (without necessarily zeros)
+    R = np.abs(rng.randn(7, 7))
+    M = R.dot(np.diag([1, 5, 3, 2, 0, 0, 0])).dot(R.T)
+    L = transformer_from_metric(M)
+    assert_allclose(L.T.dot(L), M)
+
+    # matrix with a large determinant, but which should still be considered
+    # non-definite numerically (to check that definiteness is not tested
+    # through the determinant, which is a bad strategy)
+    M = np.diag([1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e-20])
+    M = P.dot(M).dot(P.T)
+    assert np.abs(np.linalg.det(M)) > 10
+    assert np.linalg.slogdet(M)[1] > 1  # (just to show that the computed
+    # determinant is far from zero)
+    with pytest.raises(LinAlgError) as err_msg:
+      np.linalg.cholesky(M)
+    assert str(err_msg.value) == 'Matrix is not positive definite'
+    # (just to show that this case is indeed considered by numpy as an
+    # indefinite case)
+    L = transformer_from_metric(M)
+    assert_allclose(L.T.dot(L), M)
+
+    # matrix with small coefficients whose determinant is numerically zero
+    M = np.diag([1e-3, 1e-3, 1e-3, 1e-3, 1e-3, 1e-3, 1e-3])
+    L = transformer_from_metric(M)
+    assert_allclose(L.T.dot(L), M)
+
+    # full rank matrix
+    M = rng.randn(10, 10)
+    M = M.T.dot(M)
+    assert np.linalg.matrix_rank(M) == 10
+    L = transformer_from_metric(M)
+    assert_allclose(L.T.dot(L), M)
+
+  def test_non_symmetric_matrix_raises(self):
+    """Checks that if a non symmetric matrix is given to
+    transformer_from_metric, an error is thrown"""
+    rng = np.random.RandomState(42)
+    M = rng.randn(10, 10)
+    with pytest.raises(ValueError) as raised_error:
+      transformer_from_metric(M)
+    assert str(raised_error.value) == "The input metric should be symmetric."
+
+  def test_non_psd_raises(self):
+    """Checks that a non PSD matrix (i.e. with negative eigenvalues) will
+    raise an error when passed to transformer_from_metric"""
+    rng = np.random.RandomState(42)
+    D = np.diag([1, 5, 3, 4.2, -4, -2, 1])
+    P = ortho_group.rvs(7, random_state=rng)
+    M = P.dot(D).dot(P.T)
+    msg = "Matrix is not positive semidefinite (PSD)."
+    with pytest.raises(ValueError) as raised_error:
+      transformer_from_metric(M)
+    assert str(raised_error.value) == msg
+    with pytest.raises(ValueError) as raised_error:
+      transformer_from_metric(D)
+    assert str(raised_error.value) == msg
+
+  def test_almost_psd_dont_raise(self):
+    """Checks that if the metric is almost PSD (i.e. it has some negative
+    eigenvalues very close to zero), then transformer_from_metric will still
+    work"""
+    rng = np.random.RandomState(42)
+    D = np.diag([1, 5, 3, 4.2, -1e-20, -2e-20, -1e-20])
+    P = ortho_group.rvs(7, random_state=rng)
+    M = P.dot(D).dot(P.T)
+    L = transformer_from_metric(M)
+    assert_allclose(L.T.dot(L), M)
+
 
 if __name__ == '__main__':
   unittest.main()
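As a side note on the determinant test case above, here is a small standalone sketch (not part of the diff) of why a large determinant says nothing about definiteness, which is exactly the pitfall of the old `np.isclose(np.linalg.det(metric), 0)` check:

```python
import numpy as np

# An indefinite matrix can have a large positive determinant: two negative
# eigenvalues multiply into a positive factor, so a determinant check would
# wrongly route such a matrix to np.linalg.cholesky.
M = np.diag([4., -2., -3.])
print(np.linalg.det(M))       # 24.0, far from zero
print(np.linalg.eigvalsh(M))  # [-3. -2.  4.], clearly not PSD
```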
diff --git a/test/test_utils.py b/test/test_utils.py
index f1df4098..94f025c4 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -9,7 +9,8 @@
 from sklearn.base import clone
 from metric_learn._util import (check_input, make_context, preprocess_tuples,
                                 make_name, preprocess_points,
-                                check_collapsed_pairs, validate_vector)
+                                check_collapsed_pairs, validate_vector,
+                                _check_sdp_from_eigen)
 from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA,
                           LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised,
                           MMC_Supervised, RCA_Supervised, SDML_Supervised,
@@ -1030,3 +1031,19 @@ def test__validate_vector():
   x = [[1, 2], [3, 4]]
   with pytest.raises(ValueError):
     validate_vector(x)
+
+
+def test_check_sdp_from_eigen_positive_err_messages():
+  """Tests that _check_sdp_from_eigen raises a ValueError when given a
+  negative tol, and accepts a zero or positive tol"""
+  w = np.abs(np.random.RandomState(42).randn(10))
+  with pytest.raises(ValueError) as raised_error:
+    _check_sdp_from_eigen(w, -5.)
+  assert str(raised_error.value) == "tol should be positive."
+  with pytest.raises(ValueError) as raised_error:
+    _check_sdp_from_eigen(w, -1e-10)
+  assert str(raised_error.value) == "tol should be positive."
+  # a zero or positive tol is valid, and since all the eigenvalues in w are
+  # non-negative, no error should be raised:
+  _check_sdp_from_eigen(w, 1.)
+  _check_sdp_from_eigen(w, 0.)
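Finally, a short sketch (again not part of the diff) of the tolerance semantics that the fixed test relies on: an eigenvalue that is negative only at the level of floating-point noise passes the default check, while an explicit `tol=0.` rejects it:

```python
import numpy as np
from metric_learn._util import _check_sdp_from_eigen

# One eigenvalue is negative, but only at the level of float64 noise.
w = np.array([1., 2., -1e-16])

# The default tol is w.max() * len(w) * eps, about 1.3e-15 here, so the
# -1e-16 eigenvalue is treated as zero and no error is raised:
_check_sdp_from_eigen(w)

# With an explicit tol of zero, the same eigenvalues are rejected:
try:
  _check_sdp_from_eigen(w, tol=0.)
except ValueError as e:
  print(e)  # Matrix is not positive semidefinite (PSD).
```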