Skip to content

Commit df9b5bf

Browse files
author
William de Vazelhes
committed
Stores L and G in addition to what was stored before
1 parent 4c887d7 commit df9b5bf

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

metric_learn/lmnn.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,16 @@ def fit(self, X, y):
9595
L = self.L_
9696
objective = np.inf
9797

98+
# we initialize the roll back
99+
L_old = L.copy()
100+
G_old = G.copy()
101+
df_old = df.copy()
102+
a1_old = [a.copy() for a in a1]
103+
a2_old = [a.copy() for a in a2]
104+
objective_old = objective
105+
98106
# main loop
99107
for it in xrange(1, self.max_iter):
100-
df_old = df.copy()
101-
a1_old = [a.copy() for a in a1]
102-
a2_old = [a.copy() for a in a2]
103-
objective_old = objective
104108
# Compute pairwise distances under current metric
105109
Lx = L.dot(self.X_.T).T
106110
g0 = _inplace_paired_L2(*Lx[impostors])
@@ -158,14 +162,25 @@ def fit(self, X, y):
158162
if delta_obj > 0:
159163
# we're getting worse... roll back!
160164
learn_rate /= 2.0
165+
L = L_old
166+
G = G_old
161167
df = df_old
162168
a1 = a1_old
163169
a2 = a2_old
164170
objective = objective_old
165171
else:
166-
# update L
167-
L -= learn_rate * 2 * L.dot(G)
168-
learn_rate *= 1.01
172+
# We did good. We store this point as reference in case we do
173+
# worse next time.
174+
objective_old = objective
175+
L_old = L.copy()
176+
G_old = G.copy()
177+
df_old = df.copy()
178+
a1_old = [a.copy() for a in a1]
179+
a2_old = [a.copy() for a in a2]
180+
# we update L and will see in the next iteration if it does indeed
181+
# better
182+
L -= learn_rate * 2 * L.dot(G)
183+
learn_rate *= 1.01
169184

170185
# check for convergence
171186
if it > self.min_iter and abs(delta_obj) < self.convergence_tol:
@@ -177,7 +192,7 @@ def fit(self, X, y):
177192
print("LMNN didn't converge in %d steps." % self.max_iter)
178193

179194
# store the last L
180-
self.L_ = L
195+
self.L_ = L_old
181196
self.n_iter_ = it
182197
return self
183198

0 commit comments

Comments
 (0)