Commit 86d26d3

Author: William de Vazelhes (committed)

TST: make tests for LMNN gradient

1 parent d945df1

File tree: 1 file changed (+61, -4 lines)


test/metric_learn_test.py: 61 additions, 4 deletions
@@ -2,7 +2,7 @@
 import re
 import pytest
 import numpy as np
-from scipy.optimize import check_grad
+from scipy.optimize import check_grad, approx_fprime
 from six.moves import xrange
 from sklearn.metrics import pairwise_distances
 from sklearn.datasets import load_iris, make_classification, make_regression
@@ -21,7 +21,7 @@
                          RCA_Supervised, MMC_Supervised, SDML)
 # Import this specially for testing.
 from metric_learn.constraints import wrap_pairs
-from metric_learn.lmnn import python_LMNN
+from metric_learn.lmnn import python_LMNN, _sum_outer_products


 def class_separation(X, labels):
@@ -120,6 +120,61 @@ def test_iris(self):
                                   self.iris_labels)
     self.assertLess(csep, 0.25)

+  def test_loss_grad_lbfgs(self):
+    """Test gradient of loss function
+    Assert that the gradient is almost equal to its finite differences
+    approximation.
+    """
+    rng = np.random.RandomState(42)
+    X, y = make_classification(random_state=rng)
+    L = rng.randn(rng.randint(1, X.shape[1] + 1), X.shape[1])
+    lmnn = LMNN()
+
+    k = lmnn.k
+    reg = lmnn.regularization
+
+    X, y = lmnn._prepare_inputs(X, y, dtype=float,
+                                ensure_min_samples=2)
+    num_pts, num_dims = X.shape
+    unique_labels, label_inds = np.unique(y, return_inverse=True)
+    lmnn.labels_ = np.arange(len(unique_labels))
+    lmnn.transformer_ = np.eye(num_dims)
+
+    target_neighbors = lmnn._select_targets(X, label_inds)
+    impostors = lmnn._find_impostors(target_neighbors[:, -1], X, label_inds)
+
+    # sum of outer products over (point, target neighbor) pairs
+    dfG = _sum_outer_products(X, target_neighbors.flatten(),
+                              np.repeat(np.arange(X.shape[0]), k))
+    df = np.zeros_like(dfG)
+
+    # storage for active constraint sets, initially empty
+    a1 = [None] * k
+    a2 = [None] * k
+    for nn_idx in xrange(k):
+      a1[nn_idx] = np.array([])
+      a2[nn_idx] = np.array([])
+
+    # wrap the loss and gradient of _loss_grad for scipy's check_grad,
+    # which operates on flat 1-D parameter vectors
+    def fun(L):
+      return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors,
+                             1, k, reg, target_neighbors, df, a1, a2)[1]
+
+    def grad(L):
+      return lmnn._loss_grad(X, L.reshape(-1, X.shape[1]), dfG, impostors,
+                             1, k, reg, target_neighbors, df, a1, a2)[0].ravel()
+
+    # compute the error of check_grad relative to the norm of the
+    # finite-differences gradient
+    epsilon = np.sqrt(np.finfo(float).eps)
+    rel_diff = (check_grad(fun, grad, L.ravel()) /
+                np.linalg.norm(approx_fprime(L.ravel(), fun, epsilon)))
+    np.testing.assert_almost_equal(rel_diff, 0., decimal=5)
+

 def test_convergence_simple_example(capsys):
   # LMNN should converge on this simple example, which it did not with
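
A note for readers outside the codebase: _sum_outer_products is a private helper of metric_learn.lmnn, imported above only so the test can precompute dfG. Below is a minimal sketch of what a helper with this call signature plausibly computes; the body is an assumption for illustration, not code taken from the library.

import numpy as np

def sum_outer_products(data, a_inds, b_inds):
  # assumed behavior: for each index pair (a_i, b_i), accumulate the
  # outer product of the row difference, i.e.
  #   sum_i (data[a_i] - data[b_i]) (data[a_i] - data[b_i])^T
  Xab = data[a_inds] - data[b_inds]
  return Xab.T.dot(Xab)

In the test, the index pairs match each point, repeated k times, with its k target neighbors, which corresponds to the pull (target-neighbor) term of the LMNN loss.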
@@ -421,8 +476,10 @@ def grad(M):
       return nca._loss_grad_lbfgs(M, X, mask)[1].ravel()

     # compute relative error
-    rel_diff = check_grad(fun, grad, M.ravel()) / np.linalg.norm(grad(M))
-    np.testing.assert_almost_equal(rel_diff, 0., decimal=6)
+    epsilon = np.sqrt(np.finfo(float).eps)
+    rel_diff = (check_grad(fun, grad, M.ravel()) /
+                np.linalg.norm(approx_fprime(M.ravel(), fun, epsilon)))
+    np.testing.assert_almost_equal(rel_diff, 0., decimal=10)

   def test_simple_example(self):
     """Test on a simple example.
