Commit e739239

Fix covariance initialization when matrix is not invertible (#277)
* Fix covariance init when matrix is not invertible
* replaced import scipy for only required functions
* Change inv for pseudo-inv on custom matrix init
* Change from EVD to SVD
* Roll back to EVD and pseudo inverse of EVD
* Fix non-ASCII char
* rephrasing warnings
* added tests
* more rephrasing
* fix test
* add test
* fixes & adds singular pinv test from eig
* fix tolerance of assert
* fix tolerance of assert
* fix tolerance of assert
* fix random seed
* isolate random seed setting
1 parent 2380f51 commit e739239
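
The gist of the fix: when the input data contains linearly dependent features, the covariance matrix used for initialization is singular, and the exact inverse computed before this commit breaks down. A minimal sketch of that failure mode (the data construction here is illustrative, not taken from the commit):

```python
import numpy as np
from scipy.linalg import pinvh

rng = np.random.RandomState(42)
X = rng.randn(100, 3)
# append a feature that is a linear combination of the first two:
# the covariance matrix of X is now 4 x 4 but only rank 3
X = np.hstack([X, X[:, :2].dot([[2.], [3.]])])
cov = np.atleast_2d(np.cov(X, rowvar=False))

print(np.linalg.eigvalsh(cov))  # smallest eigenvalue is ~0

# inverting via `u / s` (the old code path) divides by a near-zero
# eigenvalue and yields huge or non-finite entries; the pseudo-inverse
# stays finite:
assert np.isfinite(pinvh(cov)).all()
```

After this commit, such matrices trigger a warning and fall back to a pseudo-inverse computed from the eigendecomposition, instead of raising or returning non-finite values.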

3 files changed: +151 -15 lines changed

metric_learn/_util.py

Lines changed: 53 additions & 9 deletions
```diff
@@ -1,5 +1,4 @@
 import numpy as np
-import scipy
 import six
 from numpy.linalg import LinAlgError
 from sklearn.datasets import make_spd_matrix
@@ -8,9 +7,10 @@
 from sklearn.utils.validation import check_X_y, check_random_state
 from .exceptions import PreprocessorError, NonPSDError
 from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
-from scipy.linalg import pinvh
+from scipy.linalg import pinvh, eigh
 import sys
 import time
+import warnings
 
 # hack around lack of axis kwarg in older numpy versions
 try:
@@ -678,17 +678,20 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None,
 
   random_state = check_random_state(random_state)
   M = init
-  if isinstance(init, np.ndarray):
-    s, u = scipy.linalg.eigh(init)
-    init_is_definite = _check_sdp_from_eigen(s)
+  if isinstance(M, np.ndarray):
+    w, V = eigh(M, check_finite=False)
+    init_is_definite = _check_sdp_from_eigen(w)
     if strict_pd and not init_is_definite:
       raise LinAlgError("You should provide a strictly positive definite "
                         "matrix as `{}`. This one is not definite. Try another"
                         " {}, or an algorithm that does not "
                         "require the {} to be strictly positive definite."
                         .format(*((matrix_name,) * 3)))
+    elif return_inverse and not init_is_definite:
+      warnings.warn('The initialization matrix is not invertible: '
+                    'using the pseudo-inverse instead.')
     if return_inverse:
-      M_inv = np.dot(u / s, u.T)
+      M_inv = _pseudo_inverse_from_eig(w, V)
       return M, M_inv
     else:
       return M
@@ -707,15 +710,23 @@ def _initialize_metric_mahalanobis(input, init='identity', random_state=None,
     X = input
     # atleast2d is necessary to deal with scalar covariance matrices
     M_inv = np.atleast_2d(np.cov(X, rowvar=False))
-    s, u = scipy.linalg.eigh(M_inv)
-    cov_is_definite = _check_sdp_from_eigen(s)
+    w, V = eigh(M_inv, check_finite=False)
+    cov_is_definite = _check_sdp_from_eigen(w)
     if strict_pd and not cov_is_definite:
       raise LinAlgError("Unable to get a true inverse of the covariance "
                         "matrix since it is not definite. Try another "
                         "`{}`, or an algorithm that does not "
                         "require the `{}` to be strictly positive definite."
                         .format(*((matrix_name,) * 2)))
-    M = np.dot(u / s, u.T)
+    elif not cov_is_definite:
+      warnings.warn('The covariance matrix is not invertible: '
+                    'using the pseudo-inverse instead. '
+                    'To make the covariance matrix invertible'
+                    ' you can remove any linearly dependent features and/or '
+                    'reduce the dimensionality of your input, '
+                    'for instance using `sklearn.decomposition.PCA` as a '
+                    'preprocessing step.')
+    M = _pseudo_inverse_from_eig(w, V)
     if return_inverse:
       return M, M_inv
     else:
@@ -742,3 +753,36 @@ def _check_n_components(n_features, n_components):
   if 0 < n_components <= n_features:
     return n_components
   raise ValueError('Invalid n_components, must be in [1, %d]' % n_features)
+
+
+def _pseudo_inverse_from_eig(w, V, tol=None):
+  """Compute the (Moore-Penrose) pseudo-inverse of the EVD of a symmetric
+  matrix.
+
+  Parameters
+  ----------
+  w : (..., M) ndarray
+    The eigenvalues in ascending order, each repeated according to
+    its multiplicity.
+
+  V : {(..., M, M) ndarray, (..., M, M) matrix}
+    The column ``V[:, i]`` is the normalized eigenvector corresponding
+    to the eigenvalue ``w[i]``, as returned for instance by
+    ``scipy.linalg.eigh``.
+
+  tol : positive `float`, optional
+    Absolute eigenvalues below tol are considered zero.
+
+  Returns
+  -------
+  output : (..., M, M) array_like
+    The pseudo-inverse given by the EVD.
+  """
+  if tol is None:
+    tol = np.amax(w) * np.max(w.shape) * np.finfo(w.dtype).eps
+  # discard small eigenvalues and invert the rest
+  large = np.abs(w) > tol
+  w = np.divide(1, w, where=large, out=w)
+  w[~large] = 0
+
+  return np.dot(V * w, np.conjugate(V).T)
```
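
The new `_pseudo_inverse_from_eig` helper computes `V @ diag(w_inv) @ V.T`, where `w_inv[i] = 1/w[i]` if `|w[i]| > tol` and `0` otherwise; the default `tol` is the usual eps-based cutoff (largest eigenvalue times matrix size times machine epsilon). A usage sketch, condensed from the new test in test/test_utils.py below:

```python
import numpy as np
from scipy.linalg import eigh, pinvh
from metric_learn._util import _pseudo_inverse_from_eig

rng = np.random.RandomState(42)
A = rng.rand(100, 100)
A = A + A.T                     # symmetric test matrix
w, V = eigh(A)
w[0] = 0.                       # force one zero eigenvalue
A = V.dot(np.diag(w)).dot(V.T)  # rebuild the (now singular) matrix

# the EVD-based pseudo-inverse matches scipy.linalg.pinvh
np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A),
                           rtol=1e-5)
```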

test/test_mahalanobis_mixin.py

Lines changed: 72 additions & 5 deletions
```diff
@@ -8,12 +8,12 @@
 from scipy.stats import ortho_group
 from sklearn import clone
 from sklearn.cluster import DBSCAN
-from sklearn.datasets import make_spd_matrix
-from sklearn.utils import check_random_state
+from sklearn.datasets import make_spd_matrix, make_blobs
+from sklearn.utils import check_random_state, shuffle
 from sklearn.utils.multiclass import type_of_target
 from sklearn.utils.testing import set_random_state
 
-from metric_learn._util import make_context
+from metric_learn._util import make_context, _initialize_metric_mahalanobis
 from metric_learn.base_metric import (_QuadrupletsClassifierMixin,
                                       _PairsClassifierMixin)
 from metric_learn.exceptions import NonPSDError
@@ -569,7 +569,7 @@ def test_init_mahalanobis(estimator, build_dataset):
                           in zip(ids_metric_learners,
                                  metric_learners)
                           if idml[:4] in ['ITML', 'SDML', 'LSML']])
-def test_singular_covariance_init_or_prior(estimator, build_dataset):
+def test_singular_covariance_init_or_prior_strictpd(estimator, build_dataset):
   """Tests that when using the 'covariance' init or prior, it returns the
   appropriate error if the covariance matrix is singular, for algorithms
   that need a strictly PD prior or init (see
@@ -603,6 +603,48 @@ def test_singular_covariance_init_or_prior(estimator, build_dataset):
   assert str(raised_err.value) == msg
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize('estimator, build_dataset',
+                         [(ml, bd) for idml, (ml, bd)
+                          in zip(ids_metric_learners,
+                                 metric_learners)
+                          if idml[:3] in ['MMC']],
+                         ids=[idml for idml, (ml, _)
+                              in zip(ids_metric_learners,
+                                     metric_learners)
+                              if idml[:3] in ['MMC']])
+def test_singular_covariance_init_of_non_strict_pd(estimator, build_dataset):
+  """Tests that when using the 'covariance' init or prior, it returns the
+  appropriate warning if the covariance matrix is singular, for algorithms
+  that don't need a strictly PD init. Also checks that the returned
+  inverse matrix has finite values.
+  """
+  input_data, labels, _, X = build_dataset()
+  model = clone(estimator)
+  set_random_state(model)
+  # We create a feature that is a linear combination of the first two
+  # features:
+  input_data = np.concatenate([input_data, input_data[:, ..., :2].dot([[2],
+                                                                       [3]])],
+                              axis=-1)
+  model.set_params(init='covariance')
+  msg = ('The covariance matrix is not invertible: '
+         'using the pseudo-inverse instead. '
+         'To make the covariance matrix invertible'
+         ' you can remove any linearly dependent features and/or '
+         'reduce the dimensionality of your input, '
+         'for instance using `sklearn.decomposition.PCA` as a '
+         'preprocessing step.')
+  with pytest.warns(UserWarning) as raised_warning:
+    model.fit(input_data, labels)
+  assert np.any([str(warning.message) == msg for warning in raised_warning])
+  M, _ = _initialize_metric_mahalanobis(X, init='covariance',
+                                        random_state=RNG,
+                                        return_inverse=True,
+                                        strict_pd=False)
+  assert np.isfinite(M).all()
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize('estimator, build_dataset',
                          [(ml, bd) for idml, (ml, bd)
@@ -614,7 +656,7 @@ def test_singular_covariance_init_or_prior(estimator, build_dataset):
                                  metric_learners)
                           if idml[:4] in ['ITML', 'SDML', 'LSML']])
 @pytest.mark.parametrize('w0', [1e-20, 0., -1e-20])
-def test_singular_array_init_or_prior(estimator, build_dataset, w0):
+def test_singular_array_init_or_prior_strictpd(estimator, build_dataset, w0):
   """Tests that when using a custom array init (or prior), it returns the
   appropriate error if it is singular, for algorithms
   that need a strictly PD prior or init (see
@@ -654,6 +696,31 @@ def test_singular_array_init_or_prior(estimator, build_dataset, w0):
   assert str(raised_err.value) == msg
 
 
+@pytest.mark.parametrize('w0', [1e-20, 0., -1e-20])
+def test_singular_array_init_of_non_strict_pd(w0):
+  """Tests that when using a custom array init, it returns the
+  appropriate warning if it is singular. Also checks if the returned
+  inverse matrix is finite. This isn't checked for model fitting as no
+  model currently uses this setting.
+  """
+  rng = np.random.RandomState(42)
+  X, y = shuffle(*make_blobs(random_state=rng),
+                 random_state=rng)
+  P = ortho_group.rvs(X.shape[1], random_state=rng)
+  w = np.abs(rng.randn(X.shape[1]))
+  w[0] = w0
+  M = P.dot(np.diag(w)).dot(P.T)
+  msg = ('The initialization matrix is not invertible: '
+         'using the pseudo-inverse instead.')
+  with pytest.warns(UserWarning) as raised_warning:
+    _, M_inv = _initialize_metric_mahalanobis(X, init=M,
+                                              random_state=rng,
+                                              return_inverse=True,
+                                              strict_pd=False)
+  assert str(raised_warning[0].message) == msg
+  assert np.isfinite(M_inv).all()
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize('estimator, build_dataset', metric_learners,
                          ids=ids_metric_learners)
```
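
For reference, the warning path can also be exercised directly, without fitting a full estimator. A condensed sketch based on the tests above (the data construction is illustrative; the appended feature makes the covariance matrix singular up to the eigenvalue tolerance, which is what triggers the warning):

```python
import numpy as np
from metric_learn._util import _initialize_metric_mahalanobis

rng = np.random.RandomState(42)
X = rng.randn(100, 3)
# a linearly dependent feature makes the covariance matrix singular
X = np.hstack([X, X[:, :2].dot([[2.], [3.]])])

# with strict_pd=False, a singular covariance init no longer raises:
# a UserWarning is emitted and the pseudo-inverse is used instead
M, M_inv = _initialize_metric_mahalanobis(X, init='covariance',
                                          random_state=rng,
                                          return_inverse=True,
                                          strict_pd=False)
assert np.isfinite(M).all()
```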

test/test_utils.py

Lines changed: 26 additions & 1 deletion
```diff
@@ -1,4 +1,5 @@
 import pytest
+from scipy.linalg import eigh, pinvh
 from collections import namedtuple
 import numpy as np
 from numpy.testing import assert_array_equal, assert_equal
@@ -11,7 +12,7 @@
                               check_collapsed_pairs, validate_vector,
                               _check_sdp_from_eigen, _check_n_components,
                               check_y_valid_values_for_pairs,
-                              _auto_select_init)
+                              _auto_select_init, _pseudo_inverse_from_eig)
 from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA,
                           LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised,
                           MMC_Supervised, RCA_Supervised, SDML_Supervised,
@@ -1150,3 +1151,27 @@ def test__auto_select_init(has_classes, n_features, n_samples, n_components,
   """Checks that the auto selection of the init works as expected"""
   assert (_auto_select_init(has_classes, n_features,
                             n_samples, n_components, n_classes) == result)
+
+
+@pytest.mark.parametrize('w0', [1e-20, 0., -1e-20])
+def test_pseudo_inverse_from_eig_and_pinvh_singular(w0):
+  """Checks that _pseudo_inverse_from_eig returns the same result as
+  scipy.linalg.pinvh for a singular matrix"""
+  rng = np.random.RandomState(SEED)
+  A = rng.rand(100, 100)
+  A = A + A.T
+  w, V = eigh(A)
+  w[0] = w0
+  A = V.dot(np.diag(w)).dot(V.T)
+  np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A),
+                             rtol=1e-05)
+
+
+def test_pseudo_inverse_from_eig_and_pinvh_nonsingular():
+  """Checks that _pseudo_inverse_from_eig returns the same result as
+  scipy.linalg.pinvh for a non-singular matrix"""
+  rng = np.random.RandomState(SEED)
+  A = rng.rand(100, 100)
+  A = A + A.T
+  w, V = eigh(A, check_finite=False)
+  np.testing.assert_allclose(_pseudo_inverse_from_eig(w, V), pinvh(A))
```
