Refactor LMNN as a triplets learner #352

Open · wants to merge 2 commits into master
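In short: the historical LMNN estimator, which was fit on points plus class labels, is split in two. LMNN itself becomes a weakly-supervised learner fit on triplets, and the classic labeled-points interface moves to a new LMNN_Supervised class (hence the renames throughout the tests below). The sketch that follows illustrates the intended split; the triplets-side signatures are an assumption patterned on metric-learn's existing triplet learners such as SCML, not something shown in this diff:

```python
import numpy as np
from metric_learn import LMNN, LMNN_Supervised  # LMNN_Supervised is added by this PR

X = np.random.RandomState(0).randn(40, 3)
y = np.repeat([0, 1], 20)

# Classic supervised interface: unchanged apart from the rename.
lmnn_sup = LMNN_Supervised(n_neighbors=3)
lmnn_sup.fit(X, y)
X_emb = lmnn_sup.transform(X)

# Assumed triplets interface, mirroring SCML: each triplet is
# (anchor, point to pull closer, point to push away), shape (n, 3, d).
triplets = np.array([[X[0], X[1], X[25]],
                     [X[30], X[31], X[5]]])
lmnn = LMNN()
lmnn.fit(triplets)
# Under the library's triplets-learner convention, predict() returns +1
# when the second point is closer to the anchor than the third.
print(lmnn.predict(triplets))
```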
4 changes: 2 additions & 2 deletions metric_learn/__init__.py

@@ -1,7 +1,7 @@
 from .constraints import Constraints
 from .covariance import Covariance
 from .itml import ITML, ITML_Supervised
-from .lmnn import LMNN
+from .lmnn import LMNN, LMNN_Supervised
 from .lsml import LSML, LSML_Supervised
 from .sdml import SDML, SDML_Supervised
 from .nca import NCA
@@ -14,7 +14,7 @@
 from ._version import __version__

 __all__ = ['Constraints', 'Covariance', 'ITML', 'ITML_Supervised',
-           'LMNN', 'LSML', 'LSML_Supervised', 'SDML',
+           'LMNN', 'LMNN_Supervised', 'LSML', 'LSML_Supervised', 'SDML',
            'SDML_Supervised', 'NCA', 'LFDA', 'RCA', 'RCA_Supervised',
            'MLKR', 'MMC', 'MMC_Supervised', 'SCML',
            'SCML_Supervised', '__version__']
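For existing user code, the supervised path is a pure rename; the constructor arguments exercised by the tests below (n_neighbors, learn_rate, regularization, convergence_tol, verbose) carry over unchanged. A minimal before/after sketch:

```python
# Before this PR:
#   from metric_learn import LMNN
#   lmnn = LMNN(n_neighbors=5, learn_rate=1e-6, verbose=False)

# After this PR (same arguments, new name; the LMNN name stays exported
# but now refers to the triplets learner):
from metric_learn import LMNN_Supervised
lmnn = LMNN_Supervised(n_neighbors=5, learn_rate=1e-6, verbose=False)
```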
356 changes: 339 additions & 17 deletions metric_learn/lmnn.py

Large diffs are not rendered by default.
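The 339-line rewrite of metric_learn/lmnn.py is the heart of the PR but is not rendered here. For rough orientation only, below is a hypothetical skeleton of how the split presumably sits in metric-learn's class hierarchy; the base-class and helper names are assumptions patterned on the SCML/SCML_Supervised pair, so consult the actual file for the real structure:

```python
# Hypothetical skeleton, NOT the actual diff contents.
from sklearn.base import TransformerMixin
from metric_learn.base_metric import (MahalanobisMixin,
                                      _TripletsClassifierMixin)


class LMNN(MahalanobisMixin, _TripletsClassifierMixin):
  """Weakly-supervised LMNN: learns L from triplets
  (anchor, similar, dissimilar) with the usual pull/push loss."""

  def fit(self, triplets):
    ...  # optimize so each anchor lands closer to 'similar' than 'dissimilar'


class LMNN_Supervised(LMNN, TransformerMixin):
  """Classic interface: derive triplets from labeled points
  (k target neighbors vs. differently-labeled impostors), then
  delegate to the triplets learner."""

  def fit(self, X, y):
    ...  # build triplets from (X, y, self.n_neighbors); super().fit(triplets)
```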

14 changes: 7 additions & 7 deletions test/metric_learn_test.py

@@ -19,7 +19,7 @@
   HAS_SKGGM = False
 else:
   HAS_SKGGM = True
-from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC,
+from metric_learn import (LMNN_Supervised, NCA, LFDA, Covariance, MLKR, MMC,
                           SCML_Supervised, LSML_Supervised,
                           ITML_Supervised, SDML_Supervised, RCA_Supervised,
                           MMC_Supervised, SDML, RCA, ITML, SCML)
@@ -381,7 +381,7 @@ def test_bounds_parameters_invalid(bounds):

 class TestLMNN(MetricTestCase):
   def test_iris(self):
-    lmnn = LMNN(n_neighbors=5, learn_rate=1e-6, verbose=False)
+    lmnn = LMNN_Supervised(n_neighbors=5, learn_rate=1e-6, verbose=False)
     lmnn.fit(self.iris_points, self.iris_labels)

     csep = class_separation(lmnn.transform(self.iris_points),
@@ -396,7 +396,7 @@ def test_loss_grad_lbfgs(self):
     rng = np.random.RandomState(42)
     X, y = make_classification(random_state=rng)
     L = rng.randn(rng.randint(1, X.shape[1] + 1), X.shape[1])
-    lmnn = LMNN()
+    lmnn = LMNN_Supervised()

     k = lmnn.n_neighbors
     reg = lmnn.regularization
@@ -499,7 +499,7 @@ def grad(x0):

   scipy.optimize.check_grad(loss, grad, x0.ravel())

-class LMNN_with_callback(LMNN):
+class LMNN_with_callback(LMNN_Supervised):
   """ We will use a callback to get the gradient (see later)
   """

@@ -574,7 +574,7 @@ def _loss_grad(self, X, L, dfG, k, reg, target_neighbors, label_inds):
 def test_toy_ex_lmnn(X, y, loss):
   """Test that the loss give the right result on a toy example"""
   L = np.array([[1]])
-  lmnn = LMNN(n_neighbors=1, regularization=0.5)
+  lmnn = LMNN_Supervised(n_neighbors=1, regularization=0.5)

   k = lmnn.n_neighbors
   reg = lmnn.regularization
@@ -608,7 +608,7 @@ def test_convergence_simple_example(capsys):
   # LMNN should converge on this simple example, which it did not with
   # this issue: https://github.com/scikit-learn-contrib/metric-learn/issues/88
   X, y = make_classification(random_state=0)
-  lmnn = LMNN(verbose=True)
+  lmnn = LMNN_Supervised(verbose=True)
   lmnn.fit(X, y)
   out, _ = capsys.readouterr()
   assert "LMNN converged with objective" in out
@@ -618,7 +618,7 @@ def test_no_twice_same_objective(capsys):
   # test that the objective function never has twice the same value
   # see https://github.com/scikit-learn-contrib/metric-learn/issues/88
   X, y = make_classification(random_state=0)
-  lmnn = LMNN(verbose=True)
+  lmnn = LMNN_Supervised(verbose=True)
   lmnn.fit(X, y)
   out, _ = capsys.readouterr()
   lines = re.split("\n+", out)
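One detail the callback test surfaces: user code that subclassed the old LMNN to hook into training must now subclass LMNN_Supervised, which is where the labeled-points machinery (e.g. the private _loss_grad shown above) lives. A minimal sketch of the same pattern, treating _loss_grad as internal API that may change:

```python
from metric_learn import LMNN_Supervised

class LMNNWithCallback(LMNN_Supervised):
  # Hypothetical subclass mirroring the test's LMNN_with_callback.
  def __init__(self, callback, **kwargs):
    self.callback = callback
    super().__init__(**kwargs)

  def _loss_grad(self, X, L, dfG, k, reg, target_neighbors, label_inds):
    # Signature copied from the hunk header above; forward everything and
    # let the callback observe whatever the parent implementation returns.
    result = super()._loss_grad(X, L, dfG, k, reg, target_neighbors,
                                label_inds)
    self.callback(result)
    return result
```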
4 changes: 2 additions & 2 deletions test/test_base_metric.py

@@ -45,9 +45,9 @@ def test_lmnn(self):
     nndef_kwargs = {'convergence_tol': 0.01, 'n_neighbors': 6}
     merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs)
     self.assertEqual(
-        remove_spaces(str(metric_learn.LMNN(convergence_tol=0.01,
+        remove_spaces(str(metric_learn.LMNN_Supervised(convergence_tol=0.01,
                                             n_neighbors=6))),
-        remove_spaces(f"LMNN({merged_kwargs})"))
+        remove_spaces(f"LMNN_Supervised({merged_kwargs})"))

   def test_nca(self):
     def_kwargs = {'init': 'auto', 'max_iter': 100, 'n_components': None,
4 changes: 2 additions & 2 deletions test/test_components_metric_conversion.py

@@ -7,7 +7,7 @@
 from metric_learn.sklearn_shims import ignore_warnings

 from metric_learn import (
-    LMNN, NCA, LFDA, Covariance, MLKR,
+    LMNN_Supervised, NCA, LFDA, Covariance, MLKR,
     LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised)
 from metric_learn._util import components_from_metric
 from metric_learn.exceptions import NonPSDError
@@ -42,7 +42,7 @@ def test_itml_supervised(self):
     assert_array_almost_equal(L.T.dot(L), itml.get_mahalanobis_matrix())

   def test_lmnn(self):
-    lmnn = LMNN(n_neighbors=5, learn_rate=1e-6, verbose=False)
+    lmnn = LMNN_Supervised(n_neighbors=5, learn_rate=1e-6, verbose=False)
     lmnn.fit(self.X, self.y)
     L = lmnn.components_
     assert_array_almost_equal(L.T.dot(L), lmnn.get_mahalanobis_matrix())
6 changes: 3 additions & 3 deletions test/test_fit_transform.py

@@ -4,7 +4,7 @@
 from numpy.testing import assert_array_almost_equal

 from metric_learn import (
-    LMNN, NCA, LFDA, Covariance, MLKR,
+    LMNN_Supervised, NCA, LFDA, Covariance, MLKR,
     LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised,
     MMC_Supervised)

@@ -52,11 +52,11 @@ def test_itml_supervised(self):
     assert_array_almost_equal(res_1, res_2)

   def test_lmnn(self):
-    lmnn = LMNN(n_neighbors=5, learn_rate=1e-6, verbose=False)
+    lmnn = LMNN_Supervised(n_neighbors=5, learn_rate=1e-6, verbose=False)
     lmnn.fit(self.X, self.y)
     res_1 = lmnn.transform(self.X)

-    lmnn = LMNN(n_neighbors=5, learn_rate=1e-6, verbose=False)
+    lmnn = LMNN_Supervised(n_neighbors=5, learn_rate=1e-6, verbose=False)
     res_2 = lmnn.fit_transform(self.X, self.y)

     assert_array_almost_equal(res_1, res_2)
2 changes: 1 addition & 1 deletion test/test_mahalanobis_mixin.py

@@ -437,7 +437,7 @@ def test_auto_init_transformation(n_samples, n_features, n_classes,
                                 n_components=n_components,
                                 random_state=rng)
   # To make the test work for LMNN:
-  if 'LMNN' in model_base.__class__.__name__:
+  if 'LMNN_Supervised' in model_base.__class__.__name__:
     model_base.set_params(n_neighbors=1)
   # To make the test faster for estimators that have a max_iter:
   if hasattr(model_base, 'max_iter'):
4 changes: 2 additions & 2 deletions test/test_sklearn_compat.py

@@ -7,7 +7,7 @@
 from metric_learn.sklearn_shims import (assert_allclose_dense_sparse,
                                         set_random_state, _get_args,
                                         is_public_parameter, get_scorer)
-from metric_learn import (Covariance, LFDA, LMNN, MLKR, NCA,
+from metric_learn import (Covariance, LFDA, LMNN_Supervised, MLKR, NCA,
                           ITML_Supervised, LSML_Supervised,
                           MMC_Supervised, RCA_Supervised, SDML_Supervised,
                           SCML_Supervised)
@@ -52,7 +52,7 @@ def test_covariance(self):
     check_estimator(Covariance())

   def test_lmnn(self):
-    check_estimator(LMNN())
+    check_estimator(LMNN_Supervised())

   def test_lfda(self):
     check_estimator(LFDA())
4 changes: 2 additions & 2 deletions test/test_utils.py

@@ -14,7 +14,7 @@
                           check_y_valid_values_for_pairs,
                           _auto_select_init, _pseudo_inverse_from_eig)
 from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA,
-                          LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised,
+                          LMNN_Supervised, MLKR, NCA, ITML_Supervised, LSML_Supervised,
                           MMC_Supervised, RCA_Supervised, SDML_Supervised,
                           SCML, SCML_Supervised, Constraints)
 from metric_learn.base_metric import (ArrayIndexer, MahalanobisMixin,
@@ -131,7 +131,7 @@ def build_quadruplets(with_preprocessor=False):

 classifiers = [(Covariance(), build_classification),
                (LFDA(), build_classification),
-               (LMNN(), build_classification),
+               (LMNN_Supervised(), build_classification),
                (NCA(), build_classification),
                (RCA(), build_classification),
                (ITML_Supervised(max_iter=5), build_classification),