Skip to content

[MRG] Remove shogun dependency #216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
4 commits merged on Jun 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,6 @@ package installed).

See the `sphinx documentation`_ for full documentation about installation, API, usage, and examples.

**Notes**

If a recent version of the Shogun Python modular (``modshogun``) library
is available, the LMNN implementation will use the fast C++ version from
there. The two implementations differ slightly, and the C++ version is
more complete.


.. _sphinx documentation: http://metric-learn.github.io/metric-learn/

Expand Down
9 changes: 1 addition & 8 deletions bench/benchmarks/iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,9 @@
'NCA': metric_learn.NCA(max_iter=700, n_components=2),
'RCA_Supervised': metric_learn.RCA_Supervised(dim=2, num_chunks=30,
chunk_size=2),
'SDML_Supervised': metric_learn.SDML_Supervised(num_constraints=1500),
'SDML_Supervised': metric_learn.SDML_Supervised(num_constraints=1500)
}

try:
from metric_learn.lmnn import python_LMNN
if python_LMNN is not metric_learn.LMNN:
CLASSES['python_LMNN'] = python_LMNN(k=5, learn_rate=1e-6, verbose=False)
except ImportError:
pass


class IrisDataset(object):
params = [sorted(CLASSES)]
Expand Down
8 changes: 0 additions & 8 deletions doc/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,6 @@ Alternately, download the source repository and run:
(install from commit `a0ed406 <https://github.com/skggm/skggm/commit/a0ed406586c4364ea3297a658f415e13b5cbdaf8>`_).
- For running the examples only: matplotlib

**Notes**

If a recent version of the Shogun Python modular (``modshogun``) library
is available, the LMNN implementation will use the fast C++ version from
there. The two implementations differ slightly, and the C++ version is
more complete.


Quick start
===========

Expand Down
5 changes: 0 additions & 5 deletions doc/supervised.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,6 @@ indicates :math:`\mathbf{x}_{i}, \mathbf{x}_{j}` belong to different class,
lmnn = LMNN(k=5, learn_rate=1e-6)
lmnn.fit(X, Y, verbose=False)

If a recent version of the Shogun Python modular (``modshogun``) library
is available, the LMNN implementation will use the fast C++ version from
there. Otherwise, the included pure-Python version will be used.
The two implementations differ slightly, and the C++ version is more complete.

.. topic:: References:

.. [1] `Distance Metric Learning for Large Margin Nearest Neighbor
Expand Down
46 changes: 2 additions & 44 deletions metric_learn/lmnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
from .base_metric import MahalanobisMixin


# commonality between LMNN implementations
class _base_LMNN(MahalanobisMixin, TransformerMixin):
class LMNN(MahalanobisMixin, TransformerMixin):
def __init__(self, init=None, k=3, min_iter=50, max_iter=1000,
learn_rate=1e-7, regularization=0.5, convergence_tol=0.001,
use_pca=True, verbose=False, preprocessor=None,
Expand Down Expand Up @@ -114,11 +113,7 @@ def __init__(self, init=None, k=3, min_iter=50, max_iter=1000,
self.n_components = n_components
self.num_dims = num_dims
self.random_state = random_state
super(_base_LMNN, self).__init__(preprocessor)


# slower Python version
class python_LMNN(_base_LMNN):
super(LMNN, self).__init__(preprocessor)

def fit(self, X, y):
if self.num_dims != 'deprecated':
Expand Down Expand Up @@ -344,40 +339,3 @@ def _sum_outer_products(data, a_inds, b_inds, weights=None):
if weights is not None:
return np.dot(Xab.T, Xab * weights[:,None])
return np.dot(Xab.T, Xab)


try:
# use the fast C++ version, if available
from modshogun import LMNN as shogun_LMNN
from modshogun import RealFeatures, MulticlassLabels

class LMNN(_base_LMNN):
"""Large Margin Nearest Neighbor (LMNN)

Attributes
----------
n_iter_ : `int`
The number of iterations the solver has run.

transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
The learned linear transformation ``L``.
"""

def fit(self, X, y):
X, y = self._prepare_inputs(X, y, dtype=float,
ensure_min_samples=2)
labels = MulticlassLabels(y)
self._lmnn = shogun_LMNN(RealFeatures(X.T), labels, self.k)
self._lmnn.set_maxiter(self.max_iter)
self._lmnn.set_obj_threshold(self.convergence_tol)
self._lmnn.set_regularization(self.regularization)
self._lmnn.set_stepsize(self.learn_rate)
if self.use_pca:
self._lmnn.train()
else:
self._lmnn.train(np.eye(X.shape[1]))
self.transformer_ = self._lmnn.get_linear_transform(X)
return self

except ImportError:
LMNN = python_LMNN
18 changes: 8 additions & 10 deletions test/metric_learn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
RCA_Supervised, MMC_Supervised, SDML, ITML, LSML)
# Import this specially for testing.
from metric_learn.constraints import wrap_pairs
from metric_learn.lmnn import python_LMNN, _sum_outer_products
from metric_learn.lmnn import _sum_outer_products


def class_separation(X, labels):
Expand Down Expand Up @@ -213,14 +213,12 @@ def test_bounds_parameters_invalid(bounds):

class TestLMNN(MetricTestCase):
def test_iris(self):
# Test both impls, if available.
for LMNN_cls in set((LMNN, python_LMNN)):
lmnn = LMNN_cls(k=5, learn_rate=1e-6, verbose=False)
lmnn.fit(self.iris_points, self.iris_labels)
lmnn = LMNN(k=5, learn_rate=1e-6, verbose=False)
lmnn.fit(self.iris_points, self.iris_labels)

csep = class_separation(lmnn.transform(self.iris_points),
self.iris_labels)
self.assertLess(csep, 0.25)
csep = class_separation(lmnn.transform(self.iris_points),
self.iris_labels)
self.assertLess(csep, 0.25)

def test_loss_grad_lbfgs(self):
"""Test gradient of loss function
Expand Down Expand Up @@ -336,7 +334,7 @@ def test_convergence_simple_example(capsys):
# LMNN should converge on this simple example, which it did not with
# this issue: https://github.com/metric-learn/metric-learn/issues/88
X, y = make_classification(random_state=0)
lmnn = python_LMNN(verbose=True)
lmnn = LMNN(verbose=True)
lmnn.fit(X, y)
out, _ = capsys.readouterr()
assert "LMNN converged with objective" in out
Expand All @@ -346,7 +344,7 @@ def test_no_twice_same_objective(capsys):
# test that the objective function never has twice the same value
# see https://github.com/metric-learn/metric-learn/issues/88
X, y = make_classification(random_state=0)
lmnn = python_LMNN(verbose=True)
lmnn = LMNN(verbose=True)
lmnn.fit(X, y)
out, _ = capsys.readouterr()
lines = re.split("\n+", out)
Expand Down
15 changes: 8 additions & 7 deletions test/test_base_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ def test_covariance(self):
remove_spaces("Covariance(preprocessor=None)"))

def test_lmnn(self):
self.assertRegexpMatches(
str(metric_learn.LMNN()),
r"(python_)?LMNN\(convergence_tol=0.001, init=None, k=3, "
r"learn_rate=1e-07,\s+"
r"max_iter=1000, min_iter=50, n_components=None, "
r"num_dims='deprecated',\s+preprocessor=None, random_state=None, "
r"regularization=0.5,\s+use_pca=True, verbose=False\)")
self.assertEqual(
remove_spaces(str(metric_learn.LMNN())),
remove_spaces(
"LMNN(convergence_tol=0.001, init=None, k=3, "
"learn_rate=1e-07, "
"max_iter=1000, min_iter=50, n_components=None, "
"num_dims='deprecated', preprocessor=None, random_state=None, "
"regularization=0.5, use_pca=True, verbose=False)"))

def test_nca(self):
self.assertEqual(remove_spaces(str(metric_learn.NCA())),
Expand Down