Skip to content

Commit efba316

Browse files
wdevazelhesbellet
authored andcommitted
[MRG] Use pseudo-inverse in Covariance (#206)
* FIX: fix covariance algo * some fixes and add non regression test * Use size instead of len * Address #206 (review)
1 parent fbd92ff commit efba316

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

metric_learn/covariance.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from __future__ import absolute_import
1212
import numpy as np
13+
import scipy
1314
from sklearn.base import TransformerMixin
1415

1516
from .base_metric import MahalanobisMixin
@@ -35,11 +36,11 @@ def fit(self, X, y=None):
3536
y : unused
3637
"""
3738
X = self._prepare_inputs(X, ensure_min_samples=2)
38-
M = np.cov(X, rowvar = False)
39-
if M.ndim == 0:
40-
M = 1./M
39+
M = np.atleast_2d(np.cov(X, rowvar=False))
40+
if M.size == 1:
41+
M = 1. / M
4142
else:
42-
M = np.linalg.inv(M)
43+
M = scipy.linalg.pinvh(M)
4344

4445
self.transformer_ = transformer_from_metric(np.atleast_2d(M))
4546
return self

test/metric_learn_test.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from six.moves import xrange
77
from sklearn.metrics import pairwise_distances
88
from sklearn.datasets import load_iris, make_classification, make_regression
9-
from numpy.testing import assert_array_almost_equal, assert_array_equal
9+
from numpy.testing import (assert_array_almost_equal, assert_array_equal,
10+
assert_allclose)
1011
from sklearn.utils.testing import assert_warns_message
1112
from sklearn.exceptions import ConvergenceWarning
1213
from sklearn.utils.validation import check_X_y
@@ -53,6 +54,23 @@ def test_iris(self):
5354
# deterministic result
5455
self.assertAlmostEqual(csep, 0.72981476)
5556

57+
def test_singular_returns_pseudo_inverse(self):
58+
"""Checks that if the input covariance matrix is singular, we return
59+
the pseudo inverse"""
60+
X, y = load_iris(return_X_y=True)
61+
# We add a virtual column that is a linear combination of the other
62+
# columns so that the covariance matrix will be singular
63+
X = np.concatenate([X, X[:, :2].dot([[2], [3]])], axis=1)
64+
cov_matrix = np.cov(X, rowvar=False)
65+
covariance = Covariance()
66+
covariance.fit(X)
67+
pseudo_inverse = covariance.get_mahalanobis_matrix()
68+
# here is the definition of a pseudo inverse according to wikipedia:
69+
assert_allclose(cov_matrix.dot(pseudo_inverse).dot(cov_matrix),
70+
cov_matrix)
71+
assert_allclose(pseudo_inverse.dot(cov_matrix).dot(pseudo_inverse),
72+
pseudo_inverse)
73+
5674

5775
class TestLSML(MetricTestCase):
5876
def test_iris(self):

0 commit comments

Comments
 (0)