Skip to content

Commit 99b0322

Browse files
wdevazelhes authored and bellet committed
[MRG] make transformer_from_metric more robust (#191)
* ENH: make transformer_from_metric more robust
* FIX: enhance test on an indefinite matrix with high computed determinant
* FIX: only look at the value of slogdet, not the sign
* MAINT: improve transformer_from_metric
* Address #191 (review)
1 parent edad55d commit 99b0322

File tree

3 files changed

+188
-21
lines changed

3 files changed

+188
-21
lines changed

metric_learn/_util.py

Lines changed: 63 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import warnings
12
import numpy as np
23
import six
4+
from numpy.linalg import LinAlgError
35
from sklearn.utils import check_array
46
from sklearn.utils.validation import check_X_y
57
from metric_learn.exceptions import PreprocessorError
@@ -324,31 +326,73 @@ def check_collapsed_pairs(pairs):
324326
"in total.".format(num_ident, pairs.shape[0]))
325327

326328

327-
def transformer_from_metric(metric):
328-
"""Computes the transformation matrix from the Mahalanobis matrix.
329+
def _check_sdp_from_eigen(w, tol=None):
    """Checks that the eigenvalues given are non-negative, up to a tolerance.

    Raises an error if some of the eigenvalues are negative beyond a
    tolerance level, whose default value depends on the eigenvalues.

    Parameters
    ----------
    w : array-like, shape=(n_eigenvalues,)
        Eigenvalues to check for non semidefinite positiveness.

    tol : non-negative `float`, optional
        Negative eigenvalues above - tol are considered zero. If tol is
        None, and eps is the epsilon value for datatype of w, then tol is
        set to abs(w).max() * len(w) * eps.

    Raises
    ------
    ValueError
        If `tol` is negative, or if at least one eigenvalue is smaller
        than - tol.

    See Also
    --------
    np.linalg.matrix_rank for more details on the choice of tolerance (the
    same strategy is applied here)
    """
    w = np.asarray(w)
    if not np.issubdtype(w.dtype, np.inexact):
        # np.finfo is only defined for inexact (floating point) types, so
        # integer eigenvalues are promoted before computing the tolerance
        w = w.astype(float)
    # An explicit raise (rather than an assert) so that callers get the
    # advertised ValueError, including when Python runs with -O.
    if tol is not None and tol < 0:
        raise ValueError("tol should be positive.")
    if w.size == 0:
        # no eigenvalue at all, hence no negative eigenvalue
        return
    if tol is None:
        # The magnitude is used so that the default tolerance stays
        # non-negative even when every eigenvalue is negative.
        tol = np.abs(w).max() * len(w) * np.finfo(w.dtype).eps
    if np.any(w < -tol):
        raise ValueError("Matrix is not positive semidefinite (PSD).")
353+
354+
355+
def transformer_from_metric(metric, tol=None):
    """Returns the transformation matrix from the Mahalanobis matrix.

    Returns the transformation matrix from the Mahalanobis matrix, i.e. the
    matrix L such that metric=L.T.dot(L).

    Parameters
    ----------
    metric : symmetric `np.ndarray`, shape=(d x d)
        The input metric, from which we want to extract a transformation
        matrix.

    tol : positive `float`, optional
        Eigenvalues of `metric` between 0 and - tol are considered zero.
        If tol is None, a default value depending on the eigenvalues of
        `metric` and on the epsilon value of their datatype is used (see
        `_check_sdp_from_eigen`). The tolerance is only used when the
        eigenvalues have to be inspected, i.e. when `metric` is diagonal
        or when its Cholesky decomposition fails.

    Returns
    -------
    L : np.ndarray, shape=(d x d)
        The transformation matrix, such that L.T.dot(L) == metric.

    Raises
    ------
    ValueError
        If `metric` is not a square 2D array, if it is not symmetric, or
        if it has eigenvalues that are negative beyond the tolerance.
    """
    # Checked first so that malformed inputs get a clear message instead of
    # a broadcasting error coming from the symmetry test below.
    if metric.ndim != 2 or metric.shape[0] != metric.shape[1]:
        raise ValueError("The input metric should be a square matrix.")
    if not np.allclose(metric, metric.T):
        raise ValueError("The input metric should be symmetric.")

    # If M is diagonal, its eigenvalues are its diagonal coefficients, and
    # we just return the elementwise square root. An exact comparison is
    # used so that a matrix with very small (but nonzero) off-diagonal
    # coefficients is not mistaken for a diagonal one.
    diagonal = np.diag(metric)
    if np.array_equal(metric, np.diag(diagonal)):
        _check_sdp_from_eigen(diagonal, tol)
        return np.diag(np.sqrt(np.maximum(0, diagonal)))

    try:
        # if `M` is positive definite, it admits a Cholesky decomposition:
        # L = cholesky(M).T
        return np.linalg.cholesky(metric).T
    except LinAlgError:
        # np.linalg.cholesky does not support matrices that are not
        # positive definite. If it fails we return L = w^(1/2) V.T, with
        # M = V*w*V.T being the eigenvector decomposition of M with the
        # eigenvalues in the diagonal matrix w and the columns of V being
        # the eigenvectors. Slightly negative eigenvalues (within the
        # tolerance) are clipped to zero before taking the square root.
        w, V = np.linalg.eigh(metric)
        _check_sdp_from_eigen(w, tol)
        return V.T * np.sqrt(np.maximum(0, w[:, None]))
352396

353397

354398
def validate_vector(u, dtype=None):

test/test_transformer_metric_conversion.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
import unittest
22
import numpy as np
3+
import pytest
4+
from numpy.linalg import LinAlgError
5+
from scipy.stats import ortho_group
36
from sklearn.datasets import load_iris
4-
from numpy.testing import assert_array_almost_equal
7+
from numpy.testing import assert_array_almost_equal, assert_allclose
8+
from sklearn.utils.testing import ignore_warnings
59

610
from metric_learn import (
711
LMNN, NCA, LFDA, Covariance, MLKR,
812
LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised)
13+
from metric_learn._util import transformer_from_metric
914

1015

1116
class TestTransformerMetricConversion(unittest.TestCase):
@@ -76,6 +81,105 @@ def test_mlkr(self):
7681
L = mlkr.transformer_
7782
assert_array_almost_equal(L.T.dot(L), mlkr.get_mahalanobis_matrix())
7883

84+
@ignore_warnings
def test_transformer_from_metric_edge_cases(self):
    """Checks that, on a series of edge cases, transformer_from_metric
    returns a matrix L such that L.T.dot(L) gives back the metric"""

    def check_roundtrip(metric):
        # the transformer extracted from the metric must rebuild the metric
        transformer = transformer_from_metric(metric)
        assert_allclose(transformer.T.dot(transformer), metric)

    def sandwich(outer, diagonal):
        # outer * diag(diagonal) * outer.T
        return outer.dot(np.diag(diagonal)).dot(outer.T)

    rng = np.random.RandomState(42)

    # an orthonormal matrix, used to build matrices with chosen eigenvalues
    P = ortho_group.rvs(7, random_state=rng)

    # all coefficients are very small: the matrix must not be treated as a
    # diagonal one (non regression test for
    # https://github.com/metric-learn/metric-learn/issues/175)
    tiny_eigenvalues = [1e-15, 2e-16, 3e-15, 4e-16, 5e-15, 6e-16, 7e-15]
    check_roundtrip(sandwich(P, tiny_eigenvalues))

    # diagonal matrix
    check_roundtrip(np.diag(np.abs(rng.randn(5))))

    # low-rank matrix (with zeros)
    block = rng.randn(3, 3)
    padded = np.zeros((7, 7))
    padded[:3, :3] = block.T.dot(block)
    check_roundtrip(padded)

    # low-rank matrix (without necessarily zeros)
    R = np.abs(rng.randn(7, 7))
    check_roundtrip(sandwich(R, [1, 5, 3, 2, 0, 0, 0]))

    # the determinant is still high although the matrix should be seen as
    # non-definite (checks that definiteness is not decided from the
    # determinant, which is a bad strategy)
    ill_conditioned = sandwich(P, [1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e-20])
    # (just to show that the computed determinant is far from null)
    assert np.abs(np.linalg.det(ill_conditioned)) > 10
    _, log_abs_det = np.linalg.slogdet(ill_conditioned)
    assert log_abs_det > 1
    # (just to show that numpy indeed considers this case as indefinite)
    with pytest.raises(LinAlgError) as cholesky_error:
        np.linalg.cholesky(ill_conditioned)
    assert str(cholesky_error.value) == 'Matrix is not positive definite'
    check_roundtrip(ill_conditioned)

    # lots of small nonzeros whose product is close to zero
    check_roundtrip(np.diag([1e-3] * 7))

    # full rank matrix
    G = rng.randn(10, 10)
    full_rank = G.T.dot(G)
    assert np.linalg.matrix_rank(full_rank) == 10
    check_roundtrip(full_rank)
147+
148+
def test_non_symmetric_matrix_raises(self):
    """Checks that transformer_from_metric rejects a matrix that is not
    symmetric, with the expected error message"""
    expected_msg = "The input metric should be symmetric."
    # a random dense matrix has no reason to be symmetric
    non_symmetric = np.random.RandomState(42).randn(10, 10)
    with pytest.raises(ValueError) as excinfo:
        transformer_from_metric(non_symmetric)
    assert str(excinfo.value) == expected_msg
156+
157+
def test_non_psd_raises(self):
    """Checks that transformer_from_metric raises an error on a matrix that
    is not PSD (i.e. that has negative eigenvalues), whether the matrix is
    dense or diagonal"""
    expected_msg = ("Matrix is not positive semidefinite (PSD).")
    diagonal_metric = np.diag([1, 5, 3, 4.2, -4, -2, 1])
    # same eigenvalues, expressed in a random orthonormal basis
    basis = ortho_group.rvs(7, random_state=np.random.RandomState(42))
    dense_metric = basis.dot(diagonal_metric).dot(basis.T)
    for metric in (dense_metric, diagonal_metric):
        with pytest.raises(ValueError) as excinfo:
            transformer_from_metric(metric)
        assert str(excinfo.value) == expected_msg
171+
172+
def test_almost_psd_dont_raise(self):
    """Checks that transformer_from_metric still works on a metric that is
    almost PSD (i.e. whose negative eigenvalues are very close to zero)"""
    eigenvalues = [1, 5, 3, 4.2, -1e-20, -2e-20, -1e-20]
    basis = ortho_group.rvs(7, random_state=np.random.RandomState(42))
    almost_psd = basis.dot(np.diag(eigenvalues)).dot(basis.T)
    transformer = transformer_from_metric(almost_psd)
    assert_allclose(transformer.T.dot(transformer), almost_psd)
182+
79183

80184
if __name__ == '__main__':
81185
unittest.main()

test/test_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from sklearn.base import clone
1010
from metric_learn._util import (check_input, make_context, preprocess_tuples,
1111
make_name, preprocess_points,
12-
check_collapsed_pairs, validate_vector)
12+
check_collapsed_pairs, validate_vector,
13+
_check_sdp_from_eigen)
1314
from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA,
1415
LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised,
1516
MMC_Supervised, RCA_Supervised, SDML_Supervised,
@@ -1051,3 +1052,21 @@ def test__validate_vector():
10511052
x = [[1, 2], [3, 4]]
10521053
with pytest.raises(ValueError):
10531054
validate_vector(x)
1055+
1056+
1057+
def test_check_sdp_from_eigen_positive_err_messages():
    """Tests that if _check_sdp_from_eigen is given a negative tol it raises
    an error, and if given a non-negative tol (or None) it does not"""
    # Strictly positive eigenvalues, so that the value of `tol` is the only
    # thing that can trigger an error here.
    w = np.abs(np.random.RandomState(42).randn(10)) + 1

    # Depending on how the check is implemented (a bare `assert` or an
    # explicit `raise`), the error surfaces as an AssertionError or as a
    # ValueError; the message is the contract pinned by this test.
    for negative_tol in (-5., -1e-10):
        with pytest.raises((ValueError, AssertionError)) as raised_error:
            _check_sdp_from_eigen(w, negative_tol)
        assert str(raised_error.value) == "tol should be positive."

    # valid tolerances must go through silently
    for valid_tol in (1., 0., None):
        _check_sdp_from_eigen(w, valid_tol)

0 commit comments

Comments
 (0)