@@ -156,20 +156,16 @@ def test_loss_grad_lbfgs(self):
156
156
a2 [nn_idx ] = np .array ([])
157
157
158
158
# initialize L
159
+ def loss_grad (flat_L ):
160
+ return lmnn ._loss_grad (X , flat_L .reshape (- 1 , X .shape [1 ]), dfG , impostors ,
161
+ 1 , k , reg , target_neighbors , df .copy (),
162
+ list (a1 ), list (a2 ))
159
163
160
- def fun (L ):
161
- # we copy variables that can be modified by _loss_grad, because we
162
- # want to have the same result when applying the function twice
163
- return lmnn ._loss_grad (X , L .reshape (- 1 , X .shape [1 ]), dfG , impostors ,
164
- 1 , k , reg , target_neighbors , df .copy (),
165
- list (a1 ), list (a2 ))[1 ]
164
+ def fun (x ):
165
+ loss_grad (x )[1 ]
166
166
167
- def grad (L ):
168
- # we copy variables that can be modified by _loss_grad, because we
169
- # want to have the same result when applying the function twice
170
- return lmnn ._loss_grad (X , L .reshape (- 1 , X .shape [1 ]), dfG , impostors ,
171
- 1 , k , reg , target_neighbors , df .copy (),
172
- list (a1 ), list (a2 ))[0 ].ravel ()
167
+ def grad (x ):
168
+ loss_grad (x )[0 ].ravel ()
173
169
174
170
# compute relative error
175
171
epsilon = np .sqrt (np .finfo (float ).eps )
@@ -212,19 +208,9 @@ def test_toy_ex_lmnn(X, y, loss):
212
208
a1 [nn_idx ] = np .array ([])
213
209
a2 [nn_idx ] = np .array ([])
214
210
215
- # initialize L
216
-
217
- def fun (L ):
218
- return lmnn ._loss_grad (X , L .reshape (- 1 , X .shape [1 ]), dfG , impostors , 1 ,
219
- k , reg ,
220
- target_neighbors , df , a1 , a2 )[1 ]
221
-
222
- def grad (L ):
223
- return lmnn ._loss_grad (X , L .reshape (- 1 , X .shape [1 ]), dfG , impostors , 1 ,
224
- k , reg , target_neighbors , df , a1 , a2 )[0 ].ravel ()
225
-
226
- # compute relative error
227
- assert fun (L ) == loss
211
+ # assert that the loss equals the one computed by hand
212
+ assert lmnn ._loss_grad (X , L .reshape (- 1 , X .shape [1 ]), dfG , impostors , 1 , k ,
213
+ reg , target_neighbors , df , a1 , a2 )[1 ] == loss
228
214
229
215
230
216
def test_convergence_simple_example (capsys ):
0 commit comments