Skip to content

Commit d2d77c4

Browse files
author
Björn Barz
committed
Define distance consistently as (x-y)^T*M*(x-y)
Fixes the transformes returned by ITML and LSML. The following now holds also for ITML, LSML, SDML and the covariance method: learner.transformer().T.dot(learner.transformer()) == learner.metric()
1 parent 238be72 commit d2d77c4

File tree

6 files changed

+20
-12
lines changed

6 files changed

+20
-12
lines changed

README.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ default implementations for the methods ``metric``, ``transformer``, and
4242
For an instance of a metric learner named ``foo`` learning from a set of
4343
``d``-dimensional points, ``foo.metric()`` returns a ``d x d``
4444
matrix ``M`` such that the distance between vectors ``x`` and ``y`` is
45-
expressed ``sqrt((x-y).dot(inv(M)).dot(x-y))``.
45+
expressed ``sqrt((x-y).dot(M).dot(x-y))``.
4646
Using scipy's ``pdist`` function, this would look like
47-
``pdist(X, metric='mahalanobis', VI=inv(foo.metric()))``.
47+
``pdist(X, metric='mahalanobis', VI=foo.metric())``.
4848

4949
In the same scenario, ``foo.transformer()`` returns a ``d x d``
5050
matrix ``L`` such that a vector ``x`` can be represented in the learned

metric_learn/base_metric.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ def metric(self):
2222
def transformer(self):
2323
"""Computes the transformation matrix from the Mahalanobis matrix.
2424
25-
L = inv(cholesky(M))
25+
L = cholesky(M).T
2626
2727
Returns
2828
-------
29-
L : (d x d) matrix
29+
L : upper triangular (d x d) matrix
3030
"""
31-
return inv(cholesky(self.metric()))
31+
return cholesky(self.metric()).T
3232

3333
def transform(self, X=None):
3434
"""Applies the metric transformation.

metric_learn/covariance.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,9 @@ def fit(self, X, y=None):
2828
y : unused
2929
"""
3030
self.X_ = check_array(X, ensure_min_samples=2)
31-
self.M_ = np.cov(self.X_.T)
31+
self.M_ = np.cov(self.X_, rowvar = False)
32+
if self.M_.ndim == 0:
33+
self.M_ = 1./self.M_
34+
else:
35+
self.M_ = np.linalg.inv(self.M_)
3236
return self

metric_learn/lsml.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False):
2626
tol : float, optional
2727
max_iter : int, optional
2828
prior : (d x d) matrix, optional
29-
guess at a metric [default: covariance(X)]
29+
guess at a metric [default: inv(covariance(X))]
3030
verbose : bool, optional
3131
if True, prints information while learning
3232
"""
@@ -48,7 +48,11 @@ def _prepare_inputs(self, X, constraints, weights):
4848
self.w_ = weights
4949
self.w_ /= self.w_.sum() # weights must sum to 1
5050
if self.prior is None:
51-
self.M_ = np.cov(X.T)
51+
self.M_ = np.cov(X, rowvar = False)
52+
if self.M_.ndim == 0:
53+
self.M_ = 1./self.M_
54+
else:
55+
self.M_ = np.linalg.inv(self.M_)
5256
else:
5357
self.M_ = self.prior
5458

metric_learn/sdml.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _prepare_inputs(self, X, W):
4747
W = check_array(W, accept_sparse=True)
4848
# set up prior M
4949
if self.use_cov:
50-
self.M_ = np.cov(X.T)
50+
self.M_ = pinvh(np.cov(X, rowvar = False))
5151
else:
5252
self.M_ = np.identity(X.shape[1])
5353
L = laplacian(W, normed=False)
@@ -72,11 +72,11 @@ def fit(self, X, W):
7272
Returns the instance.
7373
"""
7474
loss_matrix = self._prepare_inputs(X, W)
75-
P = pinvh(self.M_) + self.balance_param * loss_matrix
75+
P = self.M_ + self.balance_param * loss_matrix
7676
emp_cov = pinvh(P)
7777
# hack: ensure positive semidefinite
7878
emp_cov = emp_cov.T.dot(emp_cov)
79-
self.M_, _ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)
79+
_, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)
8080
return self
8181

8282

test/metric_learn_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_iris(self):
5757
itml.fit(self.iris_points, self.iris_labels)
5858

5959
csep = class_separation(itml.transform(), self.iris_labels)
60-
self.assertLess(csep, 0.4) # it's not great
60+
self.assertLess(csep, 0.2)
6161

6262

6363
class TestLMNN(MetricTestCase):

0 commit comments

Comments
 (0)