@@ -95,12 +95,16 @@ def fit(self, X, y):
95
95
L = self .L_
96
96
objective = np .inf
97
97
98
+ # we initialize the roll back
99
+ L_old = L .copy ()
100
+ G_old = G .copy ()
101
+ df_old = df .copy ()
102
+ a1_old = [a .copy () for a in a1 ]
103
+ a2_old = [a .copy () for a in a2 ]
104
+ objective_old = objective
105
+
98
106
# main loop
99
107
for it in xrange (1 , self .max_iter ):
100
- df_old = df .copy ()
101
- a1_old = [a .copy () for a in a1 ]
102
- a2_old = [a .copy () for a in a2 ]
103
- objective_old = objective
104
108
# Compute pairwise distances under current metric
105
109
Lx = L .dot (self .X_ .T ).T
106
110
g0 = _inplace_paired_L2 (* Lx [impostors ])
@@ -158,14 +162,25 @@ def fit(self, X, y):
158
162
if delta_obj > 0 :
159
163
# we're getting worse... roll back!
160
164
learn_rate /= 2.0
165
+ L = L_old
166
+ G = G_old
161
167
df = df_old
162
168
a1 = a1_old
163
169
a2 = a2_old
164
170
objective = objective_old
165
171
else :
166
- # update L
167
- L -= learn_rate * 2 * L .dot (G )
168
- learn_rate *= 1.01
172
+ # We did good. We store this point as reference in case we do
173
+ # worse next time.
174
+ objective_old = objective
175
+ L_old = L .copy ()
176
+ G_old = G .copy ()
177
+ df_old = df .copy ()
178
+ a1_old = [a .copy () for a in a1 ]
179
+ a2_old = [a .copy () for a in a2 ]
180
+ # we update L and will see in the next iteration if it does indeed
181
+ # better
182
+ L -= learn_rate * 2 * L .dot (G )
183
+ learn_rate *= 1.01
169
184
170
185
# check for convergence
171
186
if it > self .min_iter and abs (delta_obj ) < self .convergence_tol :
@@ -177,7 +192,7 @@ def fit(self, X, y):
177
192
print ("LMNN didn't converge in %d steps." % self .max_iter )
178
193
179
194
# store the last L
180
- self .L_ = L
195
+ self .L_ = L_old
181
196
self .n_iter_ = it
182
197
return self
183
198
0 commit comments