From 0f7daf2ac788593807e03a662dca149135fbe90f Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Thu, 11 Apr 2019 11:51:41 +0200 Subject: [PATCH 1/5] ENH: make transformer_from_metric more robust --- metric_learn/_util.py | 32 ++++++--- test/test_transformer_metric_conversion.py | 84 +++++++++++++++++++++- 2 files changed, 107 insertions(+), 9 deletions(-) diff --git a/metric_learn/_util.py b/metric_learn/_util.py index bd57fd5f..9d59d171 100644 --- a/metric_learn/_util.py +++ b/metric_learn/_util.py @@ -1,5 +1,7 @@ +import warnings import numpy as np import six +from numpy.linalg import LinAlgError from sklearn.utils import check_array from sklearn.utils.validation import check_X_y from metric_learn.exceptions import PreprocessorError @@ -337,18 +339,32 @@ def transformer_from_metric(metric): just return its elementwise square root (since the diagonalization of the matrix is itself). + Parameters + ---------- + metric : symmetric `np.ndarray`, shape=(d x d) + The input metric, from which we want to extract a transformation matrix. + Returns ------- - L : (d x d) matrix + L : np.ndarray, shape=(d x d) + The transformation matrix, such that L.T.dot(L) == metric. """ - - if np.allclose(metric, np.diag(np.diag(metric))): - return np.sqrt(metric) - elif not np.isclose(np.linalg.det(metric), 0): - return np.linalg.cholesky(metric).T + if not np.allclose(metric, metric.T): + raise ValueError("The input metric should be symmetric.") + abs_M = np.abs(metric) + diag_coeffs = np.diag(abs_M) + min_abs_diag_coeff = np.min(diag_coeffs) + if min_abs_diag_coeff >= 1000 * np.max(np.diag(abs_M) - abs_M): + return np.diag(np.sqrt(np.diag(metric))) else: - w, V = np.linalg.eigh(metric) - return V.T * np.sqrt(np.maximum(0, w[:, None])) + try: + return np.linalg.cholesky(metric).T + except LinAlgError as e: + warnings.warn("The Cholesky decomposition returned the following " + "error: '{}'. Using the eigendecomposition " + "instead.".format(e)) + w, V = np.linalg.eigh(metric) + return V.T * np.sqrt(np.maximum(0, w[:, None])) def validate_vector(u, dtype=None): diff --git a/test/test_transformer_metric_conversion.py b/test/test_transformer_metric_conversion.py index 6cfe8281..c695021d 100644 --- a/test/test_transformer_metric_conversion.py +++ b/test/test_transformer_metric_conversion.py @@ -1,11 +1,15 @@ import unittest import numpy as np +import pytest +from scipy.stats import ortho_group from sklearn.datasets import load_iris -from numpy.testing import assert_array_almost_equal +from numpy.testing import assert_array_almost_equal, assert_allclose +from sklearn.utils.testing import ignore_warnings from metric_learn import ( LMNN, NCA, LFDA, Covariance, MLKR, LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised) +from metric_learn._util import transformer_from_metric class TestTransformerMetricConversion(unittest.TestCase): @@ -76,6 +80,84 @@ def test_mlkr(self): L = mlkr.transformer_ assert_array_almost_equal(L.T.dot(L), mlkr.get_mahalanobis_matrix()) + @ignore_warnings + def test_transformer_from_metric_edge_cases(self): + """Test that transformer_from_metric returns the right result in various + edge cases""" + rng = np.random.RandomState(42) + + # an orthonormal matrix useful for creating matrices with given + # eigenvalues: + P = ortho_group.rvs(7, random_state=rng) + + # matrix with all its coefficients very low (to check that the algorithm + # does not consider it as a diagonal matrix)(non regression test for + # https://github.com/metric-learn/metric-learn/issues/175) + M = np.diag([1e-15, 2e-16, 3e-15, 4e-16, 5e-15, 6e-16, 7e-15]) + M = P.dot(M).dot(P.T) + L = transformer_from_metric(M) + assert_allclose(L.T.dot(L), M) + + # diagonal matrix + M = np.diag(np.abs(rng.randn(5))) + L = transformer_from_metric(M) + assert_allclose(L.T.dot(L), M) + + # low-rank matrix (with zeros) + M = np.zeros((7, 7)) + small_random = rng.randn(3, 3) + M[:3, :3] = small_random.T.dot(small_random) + L = transformer_from_metric(M) + assert_allclose(L.T.dot(L), M) + + # low-rank matrix (without necessarily zeros) + R = np.abs(rng.randn(7, 7)) + M = R.dot(np.diag([1, 5, 3, 2, 0, 0, 0])).dot(R.T) + L = transformer_from_metric(M) + assert_allclose(L.T.dot(L), M) + + # matrix with a determinant still high but which should be considered as a + # non-definite matrix + M = np.diag([1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e-10]) + M = P.dot(M).dot(P.T) + L = transformer_from_metric(M) + assert_allclose(L.T.dot(L), M) + + # matrix with lots of small nonzeros that make a big zero when multiplied + M = np.diag([1e-3, 1e-3, 1e-3, 1e-3, 1e-3, 1e-3, 1e-3]) + L = transformer_from_metric(M) + assert_allclose(L.T.dot(L), M) + + # full rank matrix + M = rng.randn(10, 10) + M = M.T.dot(M) + assert np.linalg.matrix_rank(M) == 10 + L = transformer_from_metric(M) + assert_allclose(L.T.dot(L), M) + + def test_non_symmetric_matrix_raises(self): + """Checks that if a non symmetric matrix is given to + transformer_from_metric, an error is thrown""" + rng = np.random.RandomState(42) + M = rng.randn(10, 10) + with pytest.raises(ValueError) as raised_error: + transformer_from_metric(M) + assert str(raised_error.value) == "The input metric should be symmetric." + + def test_non_psd_warns(self): + """Checks that if the matrix is not PSD it will raise a warning saying + that we will do the eigendecomposition""" + rng = np.random.RandomState(42) + R = np.abs(rng.randn(7, 7)) + M = R.dot(np.diag([1, 5, 3, 2, 0, 0, 0])).dot(R.T) + msg = ("The Cholesky decomposition returned the following " + "error: 'Matrix is not positive definite'. Using the " + "eigendecomposition instead.") + with pytest.warns(None) as raised_warning: + L = transformer_from_metric(M) + assert str(raised_warning[0].message) == msg + assert_allclose(L.T.dot(L), M) + if __name__ == '__main__': unittest.main() From c9eec1f85046624cb9ea05ebe8fa30d2fcfa361b Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Thu, 11 Apr 2019 13:30:59 +0200 Subject: [PATCH 2/5] FIX: enhance test on an undefinite matrix with high computed determinant --- test/test_transformer_metric_conversion.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/test/test_transformer_metric_conversion.py b/test/test_transformer_metric_conversion.py index c695021d..6ce31384 100644 --- a/test/test_transformer_metric_conversion.py +++ b/test/test_transformer_metric_conversion.py @@ -1,6 +1,7 @@ import unittest import numpy as np import pytest +from numpy.linalg import LinAlgError from scipy.stats import ortho_group from sklearn.datasets import load_iris from numpy.testing import assert_array_almost_equal, assert_allclose @@ -117,9 +118,18 @@ def test_transformer_from_metric_edge_cases(self): assert_allclose(L.T.dot(L), M) # matrix with a determinant still high but which should be considered as a - # non-definite matrix - M = np.diag([1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e-10]) + # non-definite matrix (to check we don't test the definiteness with the + # determinant which is a bad strategy) + M = np.diag([1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e-20]) M = P.dot(M).dot(P.T) + assert np.abs(np.linalg.det(M)) > 10 + assert np.linalg.slogdet(M) > 1 # (just to show that the computed + # determinant is far from null) + with pytest.raises(LinAlgError) as err_msg: + np.linalg.cholesky(M) + assert str(err_msg.value) == 'Matrix is not positive definite' + # (just to show that this case is indeed considered by numpy as an + # indefinite case) L = transformer_from_metric(M) assert_allclose(L.T.dot(L), M) From 5700778bd4aec8b9d0904472f8522abaa25c24ba Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Thu, 11 Apr 2019 13:58:07 +0200 Subject: [PATCH 3/5] FIX: only look at the value of slogdet, not the sign --- test/test_transformer_metric_conversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transformer_metric_conversion.py b/test/test_transformer_metric_conversion.py index 6ce31384..c00ab5f3 100644 --- a/test/test_transformer_metric_conversion.py +++ b/test/test_transformer_metric_conversion.py @@ -123,7 +123,7 @@ def test_transformer_from_metric_edge_cases(self): M = np.diag([1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e-20]) M = P.dot(M).dot(P.T) assert np.abs(np.linalg.det(M)) > 10 - assert np.linalg.slogdet(M) > 1 # (just to show that the computed + assert np.linalg.slogdet(M)[1] > 1 # (just to show that the computed # determinant is far from null) with pytest.raises(LinAlgError) as err_msg: np.linalg.cholesky(M) From 7163c331bd8ce412b887411bba8a0ec704fed7cb Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 16 Apr 2019 17:04:49 +0200 Subject: [PATCH 4/5] MAINT: improve transformer_from_metric --- metric_learn/_util.py | 47 +++++++++++++++++----- test/test_transformer_metric_conversion.py | 31 +++++++++----- 2 files changed, 57 insertions(+), 21 deletions(-) diff --git a/metric_learn/_util.py b/metric_learn/_util.py index 9d59d171..16dfbf37 100644 --- a/metric_learn/_util.py +++ b/metric_learn/_util.py @@ -326,7 +326,32 @@ def check_collapsed_pairs(pairs): "in total.".format(num_ident, pairs.shape[0])) -def transformer_from_metric(metric): +def _check_sdp_from_eigen(w, tol=None): + """Checks if some of the eigenvalues given are negative, up to a tolerance + level, with a default value of the tolerance depending on the eigenvalues. + + Parameters + ---------- + w : array-like, shape=(n_eigenvalues,) + Eigenvalues to check for non semidefinite positiveness. + + tol : float, optional + Negative eigenvalues above - tol are considered zero. If + tol is None, and w are `metric`'s eigenvalues, and eps is the + epsilon value for datatype of w, then tol is set to w.max() * len(w) * eps. + + See Also + -------- + np.linalg.matrix_rank for more details on the choice of tolerance (the same + strategy is applied here) + """ + if tol is None: + tol = w.max() * len(w) * np.finfo(w.dtype).eps + if any(w[w < 0] < - tol): + raise ValueError("Matrix is not positive semidefinite (PSD).") + + +def transformer_from_metric(metric, tol=None): """Computes the transformation matrix from the Mahalanobis matrix. Since by definition the metric `M` is positive semi-definite (PSD), it @@ -344,6 +369,12 @@ def transformer_from_metric(metric): metric : symmetric `np.ndarray`, shape=(d x d) The input metric, from which we want to extract a transformation matrix. + tol : positive float, optional + Eigenvalues of `metric` between 0 and - tol are considered zero. If tol is + None, and w are `metric`'s eigenvalues, and eps is the epsilon value for + datatype of w, then tol is set to w.max() * len(w) * eps. + + Returns ------- L : np.ndarray, shape=(d x d) @@ -351,19 +382,15 @@ def transformer_from_metric(metric): """ if not np.allclose(metric, metric.T): raise ValueError("The input metric should be symmetric.") - abs_M = np.abs(metric) - diag_coeffs = np.diag(abs_M) - min_abs_diag_coeff = np.min(diag_coeffs) - if min_abs_diag_coeff >= 1000 * np.max(np.diag(abs_M) - abs_M): - return np.diag(np.sqrt(np.diag(metric))) + if np.array_equal(metric, np.diag(np.diag(metric))): + _check_sdp_from_eigen(np.diag(metric), tol) + return np.diag(np.sqrt(np.maximum(0, np.diag(metric)))) else: try: return np.linalg.cholesky(metric).T - except LinAlgError as e: - warnings.warn("The Cholesky decomposition returned the following " - "error: '{}'. Using the eigendecomposition " - "instead.".format(e)) + except LinAlgError: w, V = np.linalg.eigh(metric) + _check_sdp_from_eigen(w, tol) return V.T * np.sqrt(np.maximum(0, w[:, None])) diff --git a/test/test_transformer_metric_conversion.py b/test/test_transformer_metric_conversion.py index c00ab5f3..80785631 100644 --- a/test/test_transformer_metric_conversion.py +++ b/test/test_transformer_metric_conversion.py @@ -154,18 +154,27 @@ def test_non_symmetric_matrix_raises(self): transformer_from_metric(M) assert str(raised_error.value) == "The input metric should be symmetric." - def test_non_psd_warns(self): - """Checks that if the matrix is not PSD it will raise a warning saying - that we will do the eigendecomposition""" + def test_non_psd_raises(self): + """Checks that a non PSD matrix (i.e. with negative eigenvalues) will + raise an error when passed to transformer_from_metric""" rng = np.random.RandomState(42) - R = np.abs(rng.randn(7, 7)) - M = R.dot(np.diag([1, 5, 3, 2, 0, 0, 0])).dot(R.T) - msg = ("The Cholesky decomposition returned the following " - "error: 'Matrix is not positive definite'. Using the " - "eigendecomposition instead.") - with pytest.warns(None) as raised_warning: - L = transformer_from_metric(M) - assert str(raised_warning[0].message) == msg + D = np.diag([1, 5, 3, 4.2, -4, -2, 1]) + P = ortho_group.rvs(7, random_state=rng) + M = P.dot(D).dot(P.T) + with pytest.raises(ValueError) as raised_error: + transformer_from_metric(M) + msg = ("Matrix is not positive semidefinite (PSD).") + assert str(raised_error.value) == msg + + def test_almost_psd_dont_raise(self): + """Checks that if the metric is almost PSD (i.e. it has some negative + eigenvalues very close to zero), then transformer_from_metric will still + work""" + rng = np.random.RandomState(42) + D = np.diag([1, 5, 3, 4.2, -1e-20, -2e-20, -1e-20]) + P = ortho_group.rvs(7, random_state=rng) + M = P.dot(D).dot(P.T) + L = transformer_from_metric(M) assert_allclose(L.T.dot(L), M) From f0088b268e0778e53c6305cf4f90e460723643ba Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Wed, 17 Apr 2019 10:05:14 +0200 Subject: [PATCH 5/5] Address https://github.com/metric-learn/metric-learn/pull/191#pullrequestreview-227267960 --- metric_learn/_util.py | 39 +++++++++++----------- test/test_transformer_metric_conversion.py | 5 ++- test/test_utils.py | 21 +++++++++++- 3 files changed, 44 insertions(+), 21 deletions(-) diff --git a/metric_learn/_util.py b/metric_learn/_util.py index 16dfbf37..7c70e4bf 100644 --- a/metric_learn/_util.py +++ b/metric_learn/_util.py @@ -335,10 +335,10 @@ def _check_sdp_from_eigen(w, tol=None): w : array-like, shape=(n_eigenvalues,) Eigenvalues to check for non semidefinite positiveness. - tol : float, optional + tol : positive `float`, optional Negative eigenvalues above - tol are considered zero. If - tol is None, and w are `metric`'s eigenvalues, and eps is the - epsilon value for datatype of w, then tol is set to w.max() * len(w) * eps. + tol is None, and eps is the epsilon value for datatype of w, then tol + is set to w.max() * len(w) * eps. See Also -------- @@ -347,33 +347,26 @@ def _check_sdp_from_eigen(w, tol=None): """ if tol is None: tol = w.max() * len(w) * np.finfo(w.dtype).eps - if any(w[w < 0] < - tol): + assert tol >= 0, ValueError("tol should be positive.") + if any(w < - tol): raise ValueError("Matrix is not positive semidefinite (PSD).") def transformer_from_metric(metric, tol=None): - """Computes the transformation matrix from the Mahalanobis matrix. - - Since by definition the metric `M` is positive semi-definite (PSD), it - admits a Cholesky decomposition: L = cholesky(M).T. However, currently the - computation of the Cholesky decomposition used does not support - non-definite matrices. If the metric is not definite, this method will - return L = V.T w^( -1/2), with M = V*w*V.T being the eigenvector - decomposition of M with the eigenvalues in the diagonal matrix w and the - columns of V being the eigenvectors. If M is diagonal, this method will - just return its elementwise square root (since the diagonalization of - the matrix is itself). + """Returns the transformation matrix from the Mahalanobis matrix. + + Returns the transformation matrix from the Mahalanobis matrix, i.e. the + matrix L such that metric=L.T.dot(L). Parameters ---------- metric : symmetric `np.ndarray`, shape=(d x d) The input metric, from which we want to extract a transformation matrix. - tol : positive float, optional + tol : positive `float`, optional Eigenvalues of `metric` between 0 and - tol are considered zero. If tol is - None, and w are `metric`'s eigenvalues, and eps is the epsilon value for - datatype of w, then tol is set to w.max() * len(w) * eps. - + None, and w_max is `metric`'s largest eigenvalue, and eps is the epsilon + value for datatype of w, then tol is set to w_max * metric.shape[0] * eps. Returns ------- @@ -382,13 +375,21 @@ def transformer_from_metric(metric, tol=None): """ if not np.allclose(metric, metric.T): raise ValueError("The input metric should be symmetric.") + # If M is diagonal, we will just return the elementwise square root: if np.array_equal(metric, np.diag(np.diag(metric))): _check_sdp_from_eigen(np.diag(metric), tol) return np.diag(np.sqrt(np.maximum(0, np.diag(metric)))) else: try: + # if `M` is positive semi-definite, it will admit a Cholesky + # decomposition: L = cholesky(M).T return np.linalg.cholesky(metric).T except LinAlgError: + # However, currently np.linalg.cholesky does not support indefinite + # matrices. So if the latter does not work we will return L = V.T w^( + # -1/2), with M = V*w*V.T being the eigenvector decomposition of M with + # the eigenvalues in the diagonal matrix w and the columns of V being the + # eigenvectors. w, V = np.linalg.eigh(metric) _check_sdp_from_eigen(w, tol) return V.T * np.sqrt(np.maximum(0, w[:, None])) diff --git a/test/test_transformer_metric_conversion.py b/test/test_transformer_metric_conversion.py index 80785631..4328320d 100644 --- a/test/test_transformer_metric_conversion.py +++ b/test/test_transformer_metric_conversion.py @@ -161,9 +161,12 @@ def test_non_psd_raises(self): D = np.diag([1, 5, 3, 4.2, -4, -2, 1]) P = ortho_group.rvs(7, random_state=rng) M = P.dot(D).dot(P.T) + msg = ("Matrix is not positive semidefinite (PSD).") with pytest.raises(ValueError) as raised_error: transformer_from_metric(M) - msg = ("Matrix is not positive semidefinite (PSD).") + assert str(raised_error.value) == msg + with pytest.raises(ValueError) as raised_error: + transformer_from_metric(D) assert str(raised_error.value) == msg def test_almost_psd_dont_raise(self): diff --git a/test/test_utils.py b/test/test_utils.py index f1df4098..94f025c4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -9,7 +9,8 @@ from sklearn.base import clone from metric_learn._util import (check_input, make_context, preprocess_tuples, make_name, preprocess_points, - check_collapsed_pairs, validate_vector) + check_collapsed_pairs, validate_vector, + _check_sdp_from_eigen) from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA, LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised, MMC_Supervised, RCA_Supervised, SDML_Supervised, @@ -1030,3 +1031,21 @@ def test__validate_vector(): x = [[1, 2], [3, 4]] with pytest.raises(ValueError): validate_vector(x) + + +def _check_sdp_from_eigen_positive_err_messages(): + """Tests that if _check_sdp_from_eigen is given a negative tol it returns + an error, and if positive it does not""" + w = np.random.RandomState(42).randn(10) + with pytest.raises(ValueError) as raised_error: + _check_sdp_from_eigen(w, -5.) + assert str(raised_error.value) == "tol should be positive." + with pytest.raises(ValueError) as raised_error: + _check_sdp_from_eigen(w, -1e-10) + assert str(raised_error.value) == "tol should be positive." + with pytest.raises(ValueError) as raised_error: + _check_sdp_from_eigen(w, 1.) + assert len(raised_error.value) == 0 + with pytest.raises(ValueError) as raised_error: + _check_sdp_from_eigen(w, 0.) + assert str(raised_error.value) == 0