Skip to content

Commit f9511a0

Browse files
author
William de Vazelhes
committed
1 parent b238b65 commit f9511a0

File tree

1 file changed

+11
-25
lines changed

1 file changed

+11
-25
lines changed

test/metric_learn_test.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -156,20 +156,16 @@ def test_loss_grad_lbfgs(self):
156156
a2[nn_idx] = np.array([])
157157

158158
# initialize L
159+
def loss_grad(flat_L):
160+
return lmnn._loss_grad(X, flat_L.reshape(-1, X.shape[1]), dfG, impostors,
161+
1, k, reg, target_neighbors, df.copy(),
162+
list(a1), list(a2))
159163

160-
def fun(L):
161-
# we copy variables that can be modified by _loss_grad, because we
162-
# want to have the same result when applying the function twice
163-
return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors,
164-
1, k, reg, target_neighbors, df.copy(),
165-
list(a1), list(a2))[1]
164+
def fun(x):
165+
loss_grad(x)[1]
166166

167-
def grad(L):
168-
# we copy variables that can be modified by _loss_grad, because we
169-
# want to have the same result when applying the function twice
170-
return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors,
171-
1, k, reg, target_neighbors, df.copy(),
172-
list(a1), list(a2))[0].ravel()
167+
def grad(x):
168+
loss_grad(x)[0].ravel()
173169

174170
# compute relative error
175171
epsilon = np.sqrt(np.finfo(float).eps)
@@ -212,19 +208,9 @@ def test_toy_ex_lmnn(X, y, loss):
212208
a1[nn_idx] = np.array([])
213209
a2[nn_idx] = np.array([])
214210

215-
# initialize L
216-
217-
def fun(L):
218-
return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1,
219-
k, reg,
220-
target_neighbors, df, a1, a2)[1]
221-
222-
def grad(L):
223-
return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1,
224-
k, reg, target_neighbors, df, a1, a2)[0].ravel()
225-
226-
# compute relative error
227-
assert fun(L) == loss
211+
# assert that the loss equals the one computed by hand
212+
assert lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors, 1, k,
213+
reg, target_neighbors, df, a1, a2)[1] == loss
228214

229215

230216
def test_convergence_simple_example(capsys):

0 commit comments

Comments
 (0)