Skip to content

Commit 187b59e

Browse files
wdevazelhes authored and perimosocordiae committed
[MRG] FIX LMNN gradient and cost function (#201)
* TST: make tests for LMNN gradient
* FIX: fix gradient computation
* Simplify expression
* Be more tolerant for checking NCA
* Address #201 (comment)
* Add checks for bounds argument
* Revert "Add checks for bounds argument" (this reverts commit 562f33b)
* Add missing return
1 parent f407bac commit 187b59e

File tree

2 files changed

+101
-6
lines changed

2 files changed

+101
-6
lines changed

metric_learn/lmnn.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def fit(self, X, y):
108108
# objective than the previous L, following the gradient:
109109
while True:
110110
# the next point next_L to try out is found by a gradient step
111-
L_next = L - 2 * learn_rate * G
111+
L_next = L - learn_rate * G
112112
# we compute the objective at next point
113113
# we copy variables that can be modified by _loss_grad, because if we
114114
# retry we don't want to modify them several times
@@ -194,10 +194,11 @@ def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df,
194194
# do the gradient update
195195
assert not np.isnan(df).any()
196196
G = dfG * reg + df * (1 - reg)
197+
G = L.dot(G)
197198
# compute the objective function
198199
objective = total_active * (1 - reg)
199-
objective += G.flatten().dot(L.T.dot(L).flatten())
200-
return G, objective, total_active, df, a1, a2
200+
objective += G.flatten().dot(L.flatten())
201+
return 2 * G, objective, total_active, df, a1, a2
201202

202203
def _select_targets(self, X, label_inds):
203204
target_neighbors = np.empty((X.shape[0], self.k), dtype=int)

test/metric_learn_test.py

Lines changed: 97 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import re
33
import pytest
44
import numpy as np
5-
from scipy.optimize import check_grad
5+
from scipy.optimize import check_grad, approx_fprime
66
from six.moves import xrange
77
from sklearn.metrics import pairwise_distances
88
from sklearn.datasets import load_iris, make_classification, make_regression
@@ -21,7 +21,7 @@
2121
RCA_Supervised, MMC_Supervised, SDML, ITML)
2222
# Import this specially for testing.
2323
from metric_learn.constraints import wrap_pairs
24-
from metric_learn.lmnn import python_LMNN
24+
from metric_learn.lmnn import python_LMNN, _sum_outer_products
2525

2626

2727
def class_separation(X, labels):
@@ -157,6 +157,98 @@ def test_iris(self):
157157
self.iris_labels)
158158
self.assertLess(csep, 0.25)
159159

160+
def test_loss_grad_lbfgs(self):
161+
"""Test gradient of loss function
162+
Assert that the gradient is almost equal to its finite differences
163+
approximation.
164+
"""
165+
rng = np.random.RandomState(42)
166+
X, y = make_classification(random_state=rng)
167+
L = rng.randn(rng.randint(1, X.shape[1] + 1), X.shape[1])
168+
lmnn = LMNN()
169+
170+
k = lmnn.k
171+
reg = lmnn.regularization
172+
173+
X, y = lmnn._prepare_inputs(X, y, dtype=float,
174+
ensure_min_samples=2)
175+
num_pts, num_dims = X.shape
176+
unique_labels, label_inds = np.unique(y, return_inverse=True)
177+
lmnn.labels_ = np.arange(len(unique_labels))
178+
lmnn.transformer_ = np.eye(num_dims)
179+
180+
target_neighbors = lmnn._select_targets(X, label_inds)
181+
impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds)
182+
183+
# sum outer products
184+
dfG = _sum_outer_products(X, target_neighbors.flatten(),
185+
np.repeat(np.arange(X.shape[0]), k))
186+
df = np.zeros_like(dfG)
187+
188+
# storage
189+
a1 = [None]*k
190+
a2 = [None]*k
191+
for nn_idx in xrange(k):
192+
a1[nn_idx] = np.array([])
193+
a2[nn_idx] = np.array([])
194+
195+
# initialize L
196+
def loss_grad(flat_L):
197+
return lmnn._loss_grad(X, flat_L.reshape(-1, X.shape[1]), dfG, impostors,
198+
1, k, reg, target_neighbors, df.copy(),
199+
list(a1), list(a2))
200+
201+
def fun(x):
202+
return loss_grad(x)[1]
203+
204+
def grad(x):
205+
return loss_grad(x)[0].ravel()
206+
207+
# compute relative error
208+
epsilon = np.sqrt(np.finfo(float).eps)
209+
rel_diff = (check_grad(fun, grad, L.ravel()) /
210+
np.linalg.norm(approx_fprime(L.ravel(), fun, epsilon)))
211+
np.testing.assert_almost_equal(rel_diff, 0., decimal=5)
212+
213+
214+
@pytest.mark.parametrize('X, y, loss', [(np.array([[0], [1], [2], [3]]),
215+
[1, 1, 0, 0], 3.0),
216+
(np.array([[0], [1], [2], [3]]),
217+
[1, 0, 0, 1], 26.)])
218+
def test_toy_ex_lmnn(X, y, loss):
219+
"""Test that the loss gives the right result on a toy example"""
220+
L = np.array([[1]])
221+
lmnn = LMNN(k=1, regularization=0.5)
222+
223+
k = lmnn.k
224+
reg = lmnn.regularization
225+
226+
X, y = lmnn._prepare_inputs(X, y, dtype=float,
227+
ensure_min_samples=2)
228+
num_pts, num_dims = X.shape
229+
unique_labels, label_inds = np.unique(y, return_inverse=True)
230+
lmnn.labels_ = np.arange(len(unique_labels))
231+
lmnn.transformer_ = np.eye(num_dims)
232+
233+
target_neighbors = lmnn._select_targets(X, label_inds)
234+
impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds)
235+
236+
# sum outer products
237+
dfG = _sum_outer_products(X, target_neighbors.flatten(),
238+
np.repeat(np.arange(X.shape[0]), k))
239+
df = np.zeros_like(dfG)
240+
241+
# storage
242+
a1 = [None]*k
243+
a2 = [None]*k
244+
for nn_idx in xrange(k):
245+
a1[nn_idx] = np.array([])
246+
a2[nn_idx] = np.array([])
247+
248+
# assert that the loss equals the one computed by hand
249+
assert lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1, k,
250+
reg, target_neighbors, df, a1, a2)[1] == loss
251+
160252

161253
def test_convergence_simple_example(capsys):
162254
# LMNN should converge on this simple example, which it did not with
@@ -458,7 +550,9 @@ def grad(M):
458550
return nca._loss_grad_lbfgs(M, X, mask)[1].ravel()
459551

460552
# compute relative error
461-
rel_diff = check_grad(fun, grad, M.ravel()) / np.linalg.norm(grad(M))
553+
epsilon = np.sqrt(np.finfo(float).eps)
554+
rel_diff = (check_grad(fun, grad, M.ravel()) /
555+
np.linalg.norm(approx_fprime(M.ravel(), fun, epsilon)))
462556
np.testing.assert_almost_equal(rel_diff, 0., decimal=6)
463557

464558
def test_simple_example(self):

0 commit comments

Comments
 (0)