
Commit 61eea28

Author: William de Vazelhes
FIX: fix gradient computation
Parent: 86d26d3

2 files changed (+63 −10)


metric_learn/lmnn.py

Lines changed: 4 additions & 2 deletions
@@ -105,7 +105,7 @@ def fit(self, X, y):
       # objective than the previous L, following the gradient:
       while True:
         # the next point next_L to try out is found by a gradient step
-        L_next = L - 2 * learn_rate * G
+        L_next = L - learn_rate * G
         # we compute the objective at next point
         # we copy variables that can be modified by _loss_grad, because if we
         # retry we don't want to modify them several times
@@ -191,10 +191,12 @@ def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df,
     # do the gradient update
     assert not np.isnan(df).any()
     G = dfG * reg + df * (1 - reg)
+
+    grad = 2 * L.dot(G)
     # compute the objective function
     objective = total_active * (1 - reg)
     objective += G.flatten().dot(L.T.dot(L).flatten())
-    return G, objective, total_active, df, a1, a2
+    return grad, objective, total_active, df, a1, a2
 
   def _select_targets(self, X, label_inds):
     target_neighbors = np.empty((X.shape[0], self.k), dtype=int)
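
Why the returned gradient is 2 * L.dot(G): the objective term computed just above is G.flatten().dot(L.T.dot(L).flatten()), i.e. tr(G^T L^T L), and G is built from sums of outer products, hence symmetric. The derivative of tr(G^T L^T L) with respect to L is L(G + G^T) = 2 L G; the previous code returned G, which is the gradient with respect to the Mahalanobis matrix M = L^T L, not with respect to L. The matching change in fit drops the factor 2 from the gradient step, since that factor now lives inside the returned gradient. A standalone numpy check of the identity (shapes and names below are illustrative, not part of the commit):

    import numpy as np

    # check that d/dL [ tr(G^T L^T L) ] == 2 L G for symmetric G,
    # which is what `grad = 2 * L.dot(G)` relies on
    rng = np.random.RandomState(42)
    L = rng.randn(3, 5)
    G = rng.randn(5, 5)
    G = G + G.T  # symmetric, like the sums of outer products in LMNN

    def objective(L_flat):
        L_mat = L_flat.reshape(3, 5)
        # same expression as in _loss_grad: <G, L^T L>
        return G.flatten().dot(L_mat.T.dot(L_mat).flatten())

    grad_analytic = 2 * L.dot(G)

    # central finite differences, one coordinate at a time
    eps = 1e-6
    grad_numeric = np.empty_like(L)
    for i in range(L.size):
        e = np.zeros(L.size)
        e[i] = eps
        grad_numeric.flat[i] = (objective(L.ravel() + e) -
                                objective(L.ravel() - e)) / (2 * eps)

    assert np.allclose(grad_analytic, grad_numeric, atol=1e-4)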

test/metric_learn_test.py

Lines changed: 59 additions & 8 deletions
@@ -158,24 +158,75 @@ def test_loss_grad_lbfgs(self):
     # initialize L
 
     def fun(L):
-      return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1,
-                             k, reg,
-                             target_neighbors, df, a1, a2)[1]
+      # we copy variables that can be modified by _loss_grad, because we
+      # want to have the same result when applying the function twice
+      return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors,
+                             1, k, reg, target_neighbors, df.copy(),
+                             list(a1), list(a2))[1]
 
     def grad(L):
+      # we copy variables that can be modified by _loss_grad, because we
+      # want to have the same result when applying the function twice
       return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors,
-                             1, k, reg,
-                             target_neighbors, df, a1, a2)[0].ravel()
+                             1, k, reg, target_neighbors, df.copy(),
+                             list(a1), list(a2))[0].ravel()
 
     # compute relative error
     epsilon = np.sqrt(np.finfo(float).eps)
     rel_diff = (check_grad(fun, grad, L.ravel()) /
-                np.linalg.norm(approx_fprime(L.ravel(), fun,
-                                             epsilon)))
-    # np.linalg.norm(grad(L))
+                np.linalg.norm(approx_fprime(L.ravel(), fun, epsilon)))
     np.testing.assert_almost_equal(rel_diff, 0., decimal=5)
 
 
+@pytest.mark.parametrize('X, y, loss', [(np.array([[0], [1], [2], [3]]),
+                                         [1, 1, 0, 0], 3.0),
+                                        (np.array([[0], [1], [2], [3]]),
+                                         [1, 0, 0, 1], 26.)])
+def test_toy_ex_lmnn(X, y, loss):
+  """Test that the loss gives the right result on a toy example"""
+  L = np.array([[1]])
+  lmnn = LMNN(k=1, regularization=0.5)
+
+  k = lmnn.k
+  reg = lmnn.regularization
+
+  X, y = lmnn._prepare_inputs(X, y, dtype=float,
+                              ensure_min_samples=2)
+  num_pts, num_dims = X.shape
+  unique_labels, label_inds = np.unique(y, return_inverse=True)
+  lmnn.labels_ = np.arange(len(unique_labels))
+  lmnn.transformer_ = np.eye(num_dims)
+
+  target_neighbors = lmnn._select_targets(X, label_inds)
+  impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds)
+
+  # sum outer products
+  dfG = _sum_outer_products(X, target_neighbors.flatten(),
+                            np.repeat(np.arange(X.shape[0]), k))
+  df = np.zeros_like(dfG)
+
+  # storage
+  a1 = [None] * k
+  a2 = [None] * k
+  for nn_idx in xrange(k):
+    a1[nn_idx] = np.array([])
+    a2[nn_idx] = np.array([])
+
+  def fun(L):
+    return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1,
+                           k, reg, target_neighbors, df, a1, a2)[1]
+
+  def grad(L):
+    return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1,
+                           k, reg, target_neighbors, df, a1, a2)[0].ravel()
+
+  assert fun(L) == loss
+
+
 def test_convergence_simple_example(capsys):
   # LMNN should converge on this simple example, which it did not with
   # this issue: https://github.com/metric-learn/metric-learn/issues/88
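
Two notes on these tests. First, the relative-error assertion in test_loss_grad_lbfgs leans on two scipy helpers: scipy.optimize.check_grad returns the 2-norm of the difference between an analytic gradient and a finite-difference estimate at a point, and scipy.optimize.approx_fprime returns that estimate itself. A self-contained sketch of the same pattern on a trivial function (everything here is illustrative, not from the commit):

    import numpy as np
    from scipy.optimize import approx_fprime, check_grad

    def fun(x):
        return x.dot(x)   # f(x) = ||x||^2

    def grad(x):
        return 2 * x      # analytic gradient of f

    x0 = np.array([1., -2., 3.])
    epsilon = np.sqrt(np.finfo(float).eps)
    # relative error: ||analytic - numeric|| / ||numeric||
    rel_diff = (check_grad(fun, grad, x0) /
                np.linalg.norm(approx_fprime(x0, fun, epsilon)))
    np.testing.assert_almost_equal(rel_diff, 0., decimal=5)

Second, the expected losses in test_toy_ex_lmnn can be checked by hand, assuming the standard LMNN objective with unit margin. In the first case (X = [0, 1, 2, 3], y = [1, 1, 0, 0], k=1, reg=0.5, L = [[1]]), every point's target neighbor sits at squared distance 1, so the pull term is 0.5 * (1 + 1 + 1 + 1) = 2; the only active impostor constraints are point 1 against impostor 2 and point 2 against impostor 1, each with hinge 1 + 1 - 1 = 1, so the push term is 0.5 * 2 = 1, giving the expected 3.0.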
