
Commit efeab88

wdevazelhes authored and perimosocordiae committed
[MRG] FIX Fix LMNN rollback (#101)
* FIX fixes #88: stores L and G in addition to what was stored before
* TST: non-regression test for this PR: test that LMNN converges on a simple example where it should converge, and that the objective function never has twice the same value
* MAINT: invert the order of the algorithm: try forward updates rather than doing rollback after wrong updates
* MAINT: update code according to comments in #101 (review)
* FIX: update also test_convergence_simple_example
* FIX: remove \xc2 character
* FIX: use list to copy list for python2 compatibility
* MAINT: make code more readable with while/break (see #101 (comment))
* FIX: remove non-ascii character
* FIX: remove keyring.deb
* STY: remove unused imports
1 parent 7441357 commit efeab88
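The heart of this commit is the inverted update order named in the message above: instead of taking a gradient step and rolling back the stored state when the objective gets worse, the loop now searches forward, shrinking the candidate step until the objective improves, committing only then, and slightly growing the learning rate afterwards. A minimal standalone sketch of that pattern (`forward_update` and `loss_fn` are hypothetical names, not the library's API; the real logic is in `LMNN.fit` in the diff below):

```python
def forward_update(L, G, objective, loss_fn, learn_rate):
    # Try gradient steps of decreasing size until one improves the
    # objective, then commit it and slightly grow the learning rate.
    # loss_fn is a hypothetical callable standing in for _loss_grad.
    while True:
        L_next = L - 2 * learn_rate * G   # candidate gradient step
        objective_next = loss_fn(L_next)
        if objective_next - objective > 0:
            learn_rate /= 2               # worse: retry with a smaller step
        else:
            break                         # no worse: accept the candidate
    return L_next, objective_next, learn_rate * 1.01
```

Because a rejected candidate never touches the accepted state, the old rollback bookkeeping (`df_old`, `a1_old`, `a2_old`, `objective_old`) disappears from the diff below.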

File tree

2 files changed: +117 −72 lines changed

metric_learn/lmnn.py

Lines changed: 85 additions & 71 deletions
@@ -90,83 +90,49 @@ def fit(self, X, y):
       a1[nn_idx] = np.array([])
       a2[nn_idx] = np.array([])
 
-    # initialize gradient and L
-    G = dfG * reg + df * (1-reg)
+    # initialize L
     L = self.L_
-    objective = np.inf
-
-    # main loop
-    for it in xrange(1, self.max_iter):
-      df_old = df.copy()
-      a1_old = [a.copy() for a in a1]
-      a2_old = [a.copy() for a in a2]
-      objective_old = objective
-      # Compute pairwise distances under current metric
-      Lx = L.dot(self.X_.T).T
-      g0 = _inplace_paired_L2(*Lx[impostors])
-      Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:,None,:])
-      g1,g2 = Ni[impostors]
-
-      # compute the gradient
-      total_active = 0
-      for nn_idx in reversed(xrange(k)):
-        act1 = g0 < g1[:,nn_idx]
-        act2 = g0 < g2[:,nn_idx]
-        total_active += act1.sum() + act2.sum()
-
-        if it > 1:
-          plus1 = act1 & ~a1[nn_idx]
-          minus1 = a1[nn_idx] & ~act1
-          plus2 = act2 & ~a2[nn_idx]
-          minus2 = a2[nn_idx] & ~act2
-        else:
-          plus1 = act1
-          plus2 = act2
-          minus1 = np.zeros(0, dtype=int)
-          minus2 = np.zeros(0, dtype=int)
-
-        targets = target_neighbors[:,nn_idx]
-        PLUS, pweight = _count_edges(plus1, plus2, impostors, targets)
-        df += _sum_outer_products(self.X_, PLUS[:,0], PLUS[:,1], pweight)
-        MINUS, mweight = _count_edges(minus1, minus2, impostors, targets)
-        df -= _sum_outer_products(self.X_, MINUS[:,0], MINUS[:,1], mweight)
-
-        in_imp, out_imp = impostors
-        df += _sum_outer_products(self.X_, in_imp[minus1], out_imp[minus1])
-        df += _sum_outer_products(self.X_, in_imp[minus2], out_imp[minus2])
-
-        df -= _sum_outer_products(self.X_, in_imp[plus1], out_imp[plus1])
-        df -= _sum_outer_products(self.X_, in_imp[plus2], out_imp[plus2])
-
-        a1[nn_idx] = act1
-        a2[nn_idx] = act2
-
-      # do the gradient update
-      assert not np.isnan(df).any()
-      G = dfG * reg + df * (1-reg)
 
-      # compute the objective function
-      objective = total_active * (1-reg)
-      objective += G.flatten().dot(L.T.dot(L).flatten())
-      assert not np.isnan(objective)
-      delta_obj = objective - objective_old
+    # first iteration: we compute variables (including objective and gradient)
+    # at initialization point
+    G, objective, total_active, df, a1, a2 = (
+        self._loss_grad(L, dfG, impostors, 1, k, reg, target_neighbors, df, a1,
+                        a2))
+
+    for it in xrange(2, self.max_iter):
+      # then at each iteration, we try to find a value of L that has better
+      # objective than the previous L, following the gradient:
+      while True:
+        # the next point next_L to try out is found by a gradient step
+        L_next = L - 2 * learn_rate * G
+        # we compute the objective at next point
+        # we copy variables that can be modified by _loss_grad, because if we
+        # retry we don't want to modify them several times
+        (G_next, objective_next, total_active_next, df_next, a1_next,
+         a2_next) = (
+            self._loss_grad(L_next, dfG, impostors, it, k, reg,
+                            target_neighbors, df.copy(), list(a1), list(a2)))
+        assert not np.isnan(objective)
+        delta_obj = objective_next - objective
+        if delta_obj > 0:
+          # if we did not find a better objective, we retry with an L closer to
+          # the starting point, by decreasing the learning rate (making the
+          # gradient step smaller)
+          learn_rate /= 2
+        else:
+          # otherwise, if we indeed found a better obj, we get out of the loop
+          break
+      # when the better L is found (and the related variables), we set the
+      # old variables to these new ones before next iteration and we
+      # slightly increase the learning rate
+      L = L_next
+      G, df, objective, total_active, a1, a2 = (
+          G_next, df_next, objective_next, total_active_next, a1_next, a2_next)
+      learn_rate *= 1.01
 
       if self.verbose:
         print(it, objective, delta_obj, total_active, learn_rate)
 
-      # update step size
-      if delta_obj > 0:
-        # we're getting worse... roll back!
-        learn_rate /= 2.0
-        df = df_old
-        a1 = a1_old
-        a2 = a2_old
-        objective = objective_old
-      else:
-        # update L
-        L -= learn_rate * 2 * L.dot(G)
-        learn_rate *= 1.01
-
       # check for convergence
       if it > self.min_iter and abs(delta_obj) < self.convergence_tol:
         if self.verbose:
@@ -181,6 +147,54 @@ def fit(self, X, y):
     self.n_iter_ = it
     return self
 
+  def _loss_grad(self, L, dfG, impostors, it, k, reg, target_neighbors, df, a1,
+                 a2):
+    # Compute pairwise distances under current metric
+    Lx = L.dot(self.X_.T).T
+    g0 = _inplace_paired_L2(*Lx[impostors])
+    Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:, None, :])
+    g1, g2 = Ni[impostors]
+    # compute the gradient
+    total_active = 0
+    for nn_idx in reversed(xrange(k)):
+      act1 = g0 < g1[:, nn_idx]
+      act2 = g0 < g2[:, nn_idx]
+      total_active += act1.sum() + act2.sum()
+
+      if it > 1:
+        plus1 = act1 & ~a1[nn_idx]
+        minus1 = a1[nn_idx] & ~act1
+        plus2 = act2 & ~a2[nn_idx]
+        minus2 = a2[nn_idx] & ~act2
+      else:
+        plus1 = act1
+        plus2 = act2
+        minus1 = np.zeros(0, dtype=int)
+        minus2 = np.zeros(0, dtype=int)
+
+      targets = target_neighbors[:, nn_idx]
+      PLUS, pweight = _count_edges(plus1, plus2, impostors, targets)
+      df += _sum_outer_products(self.X_, PLUS[:, 0], PLUS[:, 1], pweight)
+      MINUS, mweight = _count_edges(minus1, minus2, impostors, targets)
+      df -= _sum_outer_products(self.X_, MINUS[:, 0], MINUS[:, 1], mweight)
+
+      in_imp, out_imp = impostors
+      df += _sum_outer_products(self.X_, in_imp[minus1], out_imp[minus1])
+      df += _sum_outer_products(self.X_, in_imp[minus2], out_imp[minus2])
+
+      df -= _sum_outer_products(self.X_, in_imp[plus1], out_imp[plus1])
+      df -= _sum_outer_products(self.X_, in_imp[plus2], out_imp[plus2])
+
+      a1[nn_idx] = act1
+      a2[nn_idx] = act2
+    # do the gradient update
+    assert not np.isnan(df).any()
+    G = dfG * reg + df * (1 - reg)
+    # compute the objective function
+    objective = total_active * (1 - reg)
+    objective += G.flatten().dot(L.T.dot(L).flatten())
+    return G, objective, total_active, df, a1, a2
+
   def _select_targets(self):
     target_neighbors = np.empty((self.X_.shape[0], self.k), dtype=int)
     for label in self.labels_:
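A subtlety in the new loop above: `_loss_grad` mutates `df` in place and rebinds entries of `a1` and `a2`, so each trial step is handed throwaway copies — `df.copy()`, and `list(a1)` rather than `a1.copy()`, since `list.copy()` does not exist in Python 2 (per the commit message). A self-contained sketch of the aliasing hazard this avoids, where `trial` is a hypothetical stand-in for `_loss_grad`:

```python
import numpy as np

def trial(df, a1):
    # stands in for _loss_grad: mutates df in place, rebinds elements of a1
    df += 1.0
    a1[0] = ~a1[0]

df = np.zeros((2, 2))
a1 = [np.array([True]), np.array([False])]

trial(df.copy(), list(a1))   # copies: a rejected trial leaves state intact
assert df.sum() == 0.0 and a1[0][0]

trial(df, a1)                # aliases: the caller's state is modified
assert df.sum() == 4.0 and not a1[0][0]
```

A shallow `list()` copy suffices here because `_loss_grad` only rebinds `a1[nn_idx]` to a fresh array; it never mutates the stored arrays themselves.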

test/metric_learn_test.py

Lines changed: 32 additions & 1 deletion
@@ -1,5 +1,5 @@
-import re
 import unittest
+import re
 import pytest
 import numpy as np
 from scipy.optimize import check_grad
@@ -76,6 +76,37 @@ def test_iris(self):
     self.assertLess(csep, 0.25)
 
 
+def test_convergence_simple_example(capsys):
+  # LMNN should converge on this simple example, which it did not with
+  # this issue: https://github.com/metric-learn/metric-learn/issues/88
+  X, y = make_classification(random_state=0)
+  lmnn = python_LMNN(verbose=True)
+  lmnn.fit(X, y)
+  out, _ = capsys.readouterr()
+  assert "LMNN converged with objective" in out
+
+
+def test_no_twice_same_objective(capsys):
+  # test that the objective function never has twice the same value
+  # see https://github.com/metric-learn/metric-learn/issues/88
+  X, y = make_classification(random_state=0)
+  lmnn = python_LMNN(verbose=True)
+  lmnn.fit(X, y)
+  out, _ = capsys.readouterr()
+  lines = re.split("\n+", out)
+  # we get only objectives from each line:
+  # the regexp matches a float that follows an integer (the iteration
+  # number), and which is followed by a (signed) float (delta obj). It
+  # matches for instance:
+  # 3 **1113.7665747189938** -3.182774197440267 46431.0200999999999998e-06
+  objectives = [re.search("\d* (?:(\d*.\d*))[ | -]\d*.\d*", s)
+                for s in lines]
+  objectives = [match.group(1) for match in objectives if match is not None]
+  # we remove the last element because it can be equal to the penultimate
+  # if the last gradient update is null
+  assert len(objectives[:-1]) == len(set(objectives[:-1]))
+
+
 class TestSDML(MetricTestCase):
   def test_iris(self):
     # Note: this is a flaky test, which fails for certain seeds.
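The regexp in `test_no_twice_same_objective` is dense; running it on a line in the format printed by the verbose fit loop (`it`, `objective`, `delta_obj`, `total_active`, `learn_rate`) shows what `group(1)` captures. A small sketch with made-up numbers:

```python
import re

# Made-up verbose line in the format printed by fit():
# "it objective delta_obj total_active learn_rate"
line = "3 1113.766 -3.182 46431.0 2.0e-06"
match = re.search(r"\d* (?:(\d*.\d*))[ | -]\d*.\d*", line)
print(match.group(1))  # -> '1113.766', the objective at iteration 3
```

The test collects one such objective per line and asserts they are all distinct (dropping the last, which can legitimately repeat when the final gradient update is null).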
