
Commit 09dcd56

[MRG] update impostors, closer to original implem (#228)
* first attempt to change the function
* Add test and make it work
* Import the right scipy
* Add test where the number of impostors varies and tests the gradient
* fix little pbs
* Fix L_next
* Fix LMNN
* add forgotten L as argument
* add forgotten L as argument
* fix cost fn call
* fix cost fn call
* nitpicks
* make example work and fix python2 error
1 parent 731b327 commit 09dcd56

3 files changed: +169 / -69 lines changed


examples/plot_metric_learning_examples.py

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ def plot_tsne(X, y, colormap=plt.cm.Paired):
 #
 
 # setting up LMNN
-lmnn = metric_learn.LMNN(k=5, learn_rate=1e-6, init='random')
+lmnn = metric_learn.LMNN(k=5, learn_rate=1e-6)
 
 # fit the data!
 lmnn.fit(X, y)
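
For context, a minimal usage sketch of the updated example call; it assumes X and y are the feature matrix and labels already prepared earlier in the example script, and simply leaves init at its default instead of 'random':

import metric_learn

# X: (n_samples, n_features), y: (n_samples,) class labels (assumed defined above)
lmnn = metric_learn.LMNN(k=5, learn_rate=1e-6)  # no explicit init: default initialization
lmnn.fit(X, y)
X_lmnn = lmnn.transform(X)  # data mapped through the learned linear transformation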

metric_learn/lmnn.py

Lines changed: 26 additions & 51 deletions
@@ -1,8 +1,6 @@
 """
 Large Margin Nearest Neighbor Metric learning (LMNN)
 """
-# TODO: periodic recalculation of impostors, PCA initialization
-
 from __future__ import print_function, absolute_import
 import numpy as np
 import warnings
@@ -219,31 +217,19 @@ def fit(self, X, y):
                        ' (smallest class has %d)' % required_k)
 
     target_neighbors = self._select_targets(X, label_inds)
-    impostors = self._find_impostors(target_neighbors[:, -1], X, label_inds)
-    if len(impostors) == 0:
-      # L has already been initialized to an identity matrix
-      return
 
     # sum outer products
     dfG = _sum_outer_products(X, target_neighbors.flatten(),
                               np.repeat(np.arange(X.shape[0]), k))
-    df = np.zeros_like(dfG)
-
-    # storage
-    a1 = [None]*k
-    a2 = [None]*k
-    for nn_idx in xrange(k):
-      a1[nn_idx] = np.array([])
-      a2[nn_idx] = np.array([])
 
     # initialize L
     L = self.components_
 
     # first iteration: we compute variables (including objective and gradient)
     # at initialization point
-    G, objective, total_active, df, a1, a2 = (
-        self._loss_grad(X, L, dfG, impostors, 1, k, reg, target_neighbors, df,
-                        a1, a2))
+    G, objective, total_active = self._loss_grad(X, L, dfG, k,
+                                                 reg, target_neighbors,
+                                                 label_inds)
 
     it = 1  # we already made one iteration
 
@@ -257,10 +243,9 @@ def fit(self, X, y):
       # we compute the objective at next point
       # we copy variables that can be modified by _loss_grad, because if we
      # retry we don't want to modify them several times
-      (G_next, objective_next, total_active_next, df_next, a1_next,
-       a2_next) = (
-          self._loss_grad(X, L_next, dfG, impostors, it, k, reg,
-                          target_neighbors, df.copy(), list(a1), list(a2)))
+      (G_next, objective_next, total_active_next) = (
+          self._loss_grad(X, L_next, dfG, k, reg, target_neighbors,
+                          label_inds))
       assert not np.isnan(objective)
       delta_obj = objective_next - objective
       if delta_obj > 0:
@@ -275,8 +260,7 @@ def fit(self, X, y):
         # old variables to these new ones before next iteration and we
         # slightly increase the learning rate
         L = L_next
-        G, df, objective, total_active, a1, a2 = (
-            G_next, df_next, objective_next, total_active_next, a1_next, a2_next)
+        G, objective, total_active = G_next, objective_next, total_active_next
         learn_rate *= 1.01
 
       if self.verbose:
@@ -296,54 +280,45 @@ def fit(self, X, y):
     self.n_iter_ = it
     return self
 
-  def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df,
-                 a1, a2):
+  def _loss_grad(self, X, L, dfG, k, reg, target_neighbors, label_inds):
     # Compute pairwise distances under current metric
     Lx = L.dot(X.T).T
-    g0 = _inplace_paired_L2(*Lx[impostors])
+
+    # we need to find the furthest neighbor:
     Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:, None, :])
+    furthest_neighbors = np.take_along_axis(target_neighbors,
+                                            Ni.argmax(axis=1)[:, None], 1)
+    impostors = self._find_impostors(furthest_neighbors.ravel(), X,
+                                     label_inds, L)
+
+    g0 = _inplace_paired_L2(*Lx[impostors])
+
+    # we reorder the target neighbors
     g1, g2 = Ni[impostors]
     # compute the gradient
     total_active = 0
-    for nn_idx in reversed(xrange(k)):
+    df = np.zeros((X.shape[1], X.shape[1]))
+    for nn_idx in reversed(xrange(k)):  # note: reverse not useful here
       act1 = g0 < g1[:, nn_idx]
       act2 = g0 < g2[:, nn_idx]
       total_active += act1.sum() + act2.sum()
 
-      if it > 1:
-        plus1 = act1 & ~a1[nn_idx]
-        minus1 = a1[nn_idx] & ~act1
-        plus2 = act2 & ~a2[nn_idx]
-        minus2 = a2[nn_idx] & ~act2
-      else:
-        plus1 = act1
-        plus2 = act2
-        minus1 = np.zeros(0, dtype=int)
-        minus2 = np.zeros(0, dtype=int)
-
       targets = target_neighbors[:, nn_idx]
-      PLUS, pweight = _count_edges(plus1, plus2, impostors, targets)
+      PLUS, pweight = _count_edges(act1, act2, impostors, targets)
       df += _sum_outer_products(X, PLUS[:, 0], PLUS[:, 1], pweight)
-      MINUS, mweight = _count_edges(minus1, minus2, impostors, targets)
-      df -= _sum_outer_products(X, MINUS[:, 0], MINUS[:, 1], mweight)
 
       in_imp, out_imp = impostors
-      df += _sum_outer_products(X, in_imp[minus1], out_imp[minus1])
-      df += _sum_outer_products(X, in_imp[minus2], out_imp[minus2])
-
-      df -= _sum_outer_products(X, in_imp[plus1], out_imp[plus1])
-      df -= _sum_outer_products(X, in_imp[plus2], out_imp[plus2])
+      df -= _sum_outer_products(X, in_imp[act1], out_imp[act1])
+      df -= _sum_outer_products(X, in_imp[act2], out_imp[act2])
 
-      a1[nn_idx] = act1
-      a2[nn_idx] = act2
     # do the gradient update
     assert not np.isnan(df).any()
     G = dfG * reg + df * (1 - reg)
     G = L.dot(G)
     # compute the objective function
     objective = total_active * (1 - reg)
     objective += G.flatten().dot(L.flatten())
-    return 2 * G, objective, total_active, df, a1, a2
+    return 2 * G, objective, total_active
 
   def _select_targets(self, X, label_inds):
     target_neighbors = np.empty((X.shape[0], self.k), dtype=int)
@@ -355,8 +330,8 @@ def _select_targets(self, X, label_inds):
       target_neighbors[inds] = inds[nn]
     return target_neighbors
 
-  def _find_impostors(self, furthest_neighbors, X, label_inds):
-    Lx = self.transform(X)
+  def _find_impostors(self, furthest_neighbors, X, label_inds, L):
+    Lx = X.dot(L.T)
     margin_radii = 1 + _inplace_paired_L2(Lx[furthest_neighbors], Lx)
     impostors = []
     for label in self.labels_[:-1]:
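
The central change above is that _loss_grad now recomputes the impostor set at every call, from the furthest target neighbor under the current transformation L, instead of relying on impostors computed once with the initial metric at the start of fit. A minimal sketch of that selection step, under assumed shapes and with the pairwise distances spelled out via einsum rather than the library's internal _inplace_paired_L2 helper:

import numpy as np

def furthest_target_neighbors(X, L, target_neighbors):
    # X: (n_samples, n_features), L: (n_components, n_features),
    # target_neighbors: (n_samples, k) indices of same-class neighbors.
    Lx = X.dot(L.T)                                  # project the data with the current L
    diffs = Lx[target_neighbors] - Lx[:, None, :]    # (n_samples, k, n_components)
    Ni = 1 + np.einsum('nkc,nkc->nk', diffs, diffs)  # margin-shifted squared distances
    # per sample, keep the target neighbor that is currently the furthest away;
    # impostors are then the differently-labeled points falling inside that radius
    return np.take_along_axis(target_neighbors,
                              Ni.argmax(axis=1)[:, None], axis=1).ravel()

Recomputing impostors from these per-sample radii at each gradient evaluation is what moves the solver closer to the original LMNN implementation, at the price of an extra neighbor search per iteration.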

test/metric_learn_test.py

Lines changed: 142 additions & 17 deletions
@@ -2,9 +2,10 @@
 import re
 import pytest
 import numpy as np
+import scipy
 from scipy.optimize import check_grad, approx_fprime
 from six.moves import xrange
-from sklearn.metrics import pairwise_distances
+from sklearn.metrics import pairwise_distances, euclidean_distances
 from sklearn.datasets import (load_iris, make_classification, make_regression,
                               make_spd_matrix)
 from numpy.testing import (assert_array_almost_equal, assert_array_equal,
@@ -304,25 +305,15 @@ def test_loss_grad_lbfgs(self):
     lmnn.components_ = np.eye(n_components)
 
     target_neighbors = lmnn._select_targets(X, label_inds)
-    impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds)
 
     # sum outer products
     dfG = _sum_outer_products(X, target_neighbors.flatten(),
                               np.repeat(np.arange(X.shape[0]), k))
-    df = np.zeros_like(dfG)
-
-    # storage
-    a1 = [None]*k
-    a2 = [None]*k
-    for nn_idx in xrange(k):
-      a1[nn_idx] = np.array([])
-      a2[nn_idx] = np.array([])
 
     # initialize L
     def loss_grad(flat_L):
-      return lmnn._loss_grad(X, flat_L.reshape(-1, X.shape[1]), dfG, impostors,
-                             1, k, reg, target_neighbors, df.copy(),
-                             list(a1), list(a2))
+      return lmnn._loss_grad(X, flat_L.reshape(-1, X.shape[1]), dfG,
+                             k, reg, target_neighbors, label_inds)
 
     def fun(x):
       return loss_grad(x)[1]
@@ -366,6 +357,141 @@ def test_deprecation_use_pca(self):
     assert_warns_message(DeprecationWarning, msg, lmnn.fit, X, y)
 
 
+def test_loss_func(capsys):
+  """Test the loss function (and its gradient) on a simple example,
+  by comparing the results with the actual implementation of metric-learn,
+  with a very simple (but nonperformant) implementation"""
+
+  # toy dataset to use
+  X, y = make_classification(n_samples=10, n_classes=2,
+                             n_features=6,
+                             n_redundant=0, shuffle=True,
+                             scale=[1, 1, 20, 20, 20, 20], random_state=42)
+
+  def hinge(a):
+    if a > 0:
+      return a, 1
+    else:
+      return 0, 0
+
+  def loss_fn(L, X, y, target_neighbors, reg):
+    L = L.reshape(-1, X.shape[1])
+    Lx = np.dot(X, L.T)
+    loss = 0
+    total_active = 0
+    grad = np.zeros_like(L)
+    for i in range(X.shape[0]):
+      for j in target_neighbors[i]:
+        loss += (1 - reg) * np.sum((Lx[i] - Lx[j]) ** 2)
+        grad += (1 - reg) * np.outer(Lx[i] - Lx[j], X[i] - X[j])
+        for l in range(X.shape[0]):
+          if y[i] != y[l]:
+            hin, active = hinge(1 + np.sum((Lx[i] - Lx[j])**2) -
+                                np.sum((Lx[i] - Lx[l])**2))
+            total_active += active
+            if active:
+              loss += reg * hin
+              grad += (reg * (np.outer(Lx[i] - Lx[j], X[i] - X[j]) -
+                              np.outer(Lx[i] - Lx[l], X[i] - X[l])))
+    grad = 2 * grad
+    return grad, loss, total_active
+
+  # we check that the gradient we have computed in the non-performant implem
+  # is indeed the true gradient on a toy example:
+
+  def _select_targets(X, y, k):
+    target_neighbors = np.empty((X.shape[0], k), dtype=int)
+    for label in np.unique(y):
+      inds, = np.nonzero(y == label)
+      dd = euclidean_distances(X[inds], squared=True)
+      np.fill_diagonal(dd, np.inf)
+      nn = np.argsort(dd)[..., :k]
+      target_neighbors[inds] = inds[nn]
+    return target_neighbors
+
+  target_neighbors = _select_targets(X, y, 2)
+  regularization = 0.5
+  n_features = X.shape[1]
+  x0 = np.random.randn(1, n_features)
+
+  def loss(x0):
+    return loss_fn(x0.reshape(-1, X.shape[1]), X, y, target_neighbors,
+                   regularization)[1]
+
+  def grad(x0):
+    return loss_fn(x0.reshape(-1, X.shape[1]), X, y, target_neighbors,
+                   regularization)[0].ravel()
+
+  scipy.optimize.check_grad(loss, grad, x0.ravel())
+
+  class LMNN_with_callback(LMNN):
+    """ We will use a callback to get the gradient (see later)
+    """
+
+    def __init__(self, callback, *args, **kwargs):
+      self.callback = callback
+      super(LMNN_with_callback, self).__init__(*args, **kwargs)
+
+    def _loss_grad(self, *args, **kwargs):
+      grad, objective, total_active = (
+          super(LMNN_with_callback, self)._loss_grad(*args, **kwargs))
+      self.callback.append(grad)
+      return grad, objective, total_active
+
+  class LMNN_nonperformant(LMNN_with_callback):
+
+    def fit(self, X, y):
+      self.y = y
+      return super(LMNN_nonperformant, self).fit(X, y)
+
+    def _loss_grad(self, X, L, dfG, k, reg, target_neighbors, label_inds):
+      grad, loss, total_active = loss_fn(L.ravel(), X, self.y,
+                                         target_neighbors, self.regularization)
+      self.callback.append(grad)
+      return grad, loss, total_active
+
+  mem1, mem2 = [], []
+  lmnn_perf = LMNN_with_callback(verbose=True, random_state=42,
+                                 init='identity', max_iter=30, callback=mem1)
+  lmnn_nonperf = LMNN_nonperformant(verbose=True, random_state=42,
+                                    init='identity', max_iter=30,
+                                    callback=mem2)
+  objectives, obj_diffs, learn_rate, total_active = (dict(), dict(), dict(),
+                                                     dict())
+  for algo, name in zip([lmnn_perf, lmnn_nonperf], ['perf', 'nonperf']):
+    algo.fit(X, y)
+    out, _ = capsys.readouterr()
+    lines = re.split("\n+", out)
+    # we get every variable that is printed from the algorithm in verbose
+    num = '(-?\d+.?\d*(e[+|-]\d+)?)'
+    strings = [re.search("\d+ (?:{}) (?:{}) (?:(\d+)) (?:{})"
+               .format(num, num, num), s) for s in lines]
+    objectives[name] = [float(match.group(1)) for match in strings if match is
+                        not None]
+    obj_diffs[name] = [float(match.group(3)) for match in strings if match is
+                       not None]
+    total_active[name] = [float(match.group(5)) for match in strings if
+                          match is not
+                          None]
+    learn_rate[name] = [float(match.group(6)) for match in strings if match is
+                        not None]
+    assert len(strings) >= 10  # we ensure that we actually did more than 10
+    # iterations
+    assert total_active[name][0] >= 2  # we ensure that we have some active
+    # constraints (that's the case we want to test)
+  # we remove the last element because it can be equal to the penultimate
+  # if the last gradient update is null
+  for i in range(len(mem1)):
+    np.testing.assert_allclose(lmnn_perf.callback[i],
+                               lmnn_nonperf.callback[i],
+                               err_msg='Gradient different at position '
+                                       '{}'.format(i))
+  np.testing.assert_allclose(objectives['perf'], objectives['nonperf'])
+  np.testing.assert_allclose(obj_diffs['perf'], obj_diffs['nonperf'])
+  np.testing.assert_allclose(total_active['perf'], total_active['nonperf'])
+  np.testing.assert_allclose(learn_rate['perf'], learn_rate['nonperf'])
+
+
 @pytest.mark.parametrize('X, y, loss', [(np.array([[0], [1], [2], [3]]),
                                          [1, 1, 0, 0], 3.0),
                                         (np.array([[0], [1], [2], [3]]),
@@ -386,7 +512,7 @@ def test_toy_ex_lmnn(X, y, loss):
   lmnn.components_ = np.eye(n_components)
 
   target_neighbors = lmnn._select_targets(X, label_inds)
-  impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds)
+  impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds, L)
 
   # sum outer products
   dfG = _sum_outer_products(X, target_neighbors.flatten(),
@@ -401,9 +527,8 @@ def test_toy_ex_lmnn(X, y, loss):
     a2[nn_idx] = np.array([])
 
   # assert that the loss equals the one computed by hand
-  assert lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1, k,
-                         reg, target_neighbors, df, a1, a2)[1] == loss
-
+  assert lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, k,
+                         reg, target_neighbors, label_inds)[1] == loss
 
 def test_convergence_simple_example(capsys):
   # LMNN should converge on this simple example, which it did not with
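
For reference, the objective that the naive loss_fn in the new test computes (and that _loss_grad is checked against) is the standard LMNN cost. Writing \mu for the reg argument passed to loss_fn, j ⇝ i for "j is a target neighbor of i", and [·]_+ for the hinge, it reads:

\varepsilon(L) = (1 - \mu) \sum_{i,\; j \rightsquigarrow i} \lVert L(x_i - x_j) \rVert^2
               + \mu \sum_{i,\; j \rightsquigarrow i} \; \sum_{l : y_l \neq y_i}
                 \bigl[\, 1 + \lVert L(x_i - x_j) \rVert^2 - \lVert L(x_i - x_l) \rVert^2 \,\bigr]_+

A point l of another class whose hinge term is active for some pair (i, j) is an impostor; these active constraints are what the act1/act2 masks accumulate into total_active in _loss_grad. The test runs both implementations with regularization = 0.5, where the pull and push terms carry equal weight.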
