Commit f3c690e: Remove the need for skggm

Author: William de Vazelhes (committed)
1 parent: 4b0bae9

9 files changed, 41 additions and 104 deletions

metric_learn/_util.py

Lines changed: 0 additions & 8 deletions

@@ -15,14 +15,6 @@ def vector_norm(X):
   return np.linalg.norm(X, axis=1)
 
 
-def has_installed_skggm():
-  try:
-    import inverse_covariance
-    return True
-  except ImportError:
-    return False
-
-
 def check_input(input_data, y=None, preprocessor=None,
                 type_of_inputs='classic', tuple_size=None, accept_sparse=False,
                 dtype='numeric', order=None,

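For reference, the deleted helper was the usual optional-dependency probe. A minimal sketch of the same pattern, now pointed at the scikit-learn function this commit switches to (the helper name below is hypothetical, not part of metric-learn):

def has_graphical_lasso():
  # graphical_lasso lives in sklearn.covariance from scikit-learn 0.20 on.
  try:
    from sklearn.covariance import graphical_lasso  # noqa: F401
    return True
  except ImportError:
    return False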
metric_learn/sdml.py

Lines changed: 6 additions & 13 deletions

@@ -13,13 +13,12 @@
 import numpy as np
 from sklearn.base import TransformerMixin
 from scipy.linalg import pinvh
+from sklearn.covariance import graphical_lasso
 from sklearn.exceptions import ConvergenceWarning
 
 from .base_metric import MahalanobisMixin, _PairsClassifierMixin
 from .constraints import Constraints, wrap_pairs
-from ._util import transformer_from_metric, has_installed_skggm
-if has_installed_skggm():
-  from inverse_covariance import quic
+from ._util import transformer_from_metric
 
 
 class _BaseSDML(MahalanobisMixin):
@@ -47,11 +46,6 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
       The preprocessor to call to get tuples from indices. If array-like,
       tuples will be gotten like this: X[indices].
     """
-    if not has_installed_skggm():
-      raise NotImplementedError("SDML cannot be instantiated without "
-                                "installing skggm. Please install skggm and "
-                                "try again (make sure you meet skggm's "
-                                "requirements).")
     self.balance_param = balance_param
     self.sparsity_param = sparsity_param
     self.use_cov = use_cov
@@ -83,11 +77,10 @@ def _fit(self, pairs, y):
                     "To prevent that, try to decrease the balance parameter "
                     "`balance_param` and/or to set use_covariance=False.",
                     ConvergenceWarning)
-    sigma0 = (V * (w - min(0, np.min(w)) + 1e-10)).dot(V.T)
-    theta0 = pinvh(sigma0)
-    M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
-                            msg=self.verbose,
-                            Theta0=theta0, Sigma0=sigma0)
+    cov_init = (V * (w - min(0, np.min(w)) + 1e-10)).dot(V.T)
+    _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
+                           verbose=self.verbose,
+                           cov_init=cov_init)
     self.transformer_ = transformer_from_metric(np.atleast_2d(M))
     return self
 

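This is the heart of the commit: skggm's quic solver is replaced by scikit-learn's graphical_lasso. Both estimate a sparse precision matrix, but quic returns six values with the precision first, while graphical_lasso returns a (covariance, precision) pair, so only the call and the unpacking change. A standalone sketch of the new code path, assuming a scikit-learn version contemporary with this commit (>= 0.20, where graphical_lasso still accepts cov_init); the function name and the toy input below are illustrative, not part of the library:

import numpy as np
from sklearn.covariance import graphical_lasso


def sparse_metric(emp_cov, sparsity_param=0.01, verbose=False):
  """Illustrative stand-in for the updated part of _BaseSDML._fit."""
  # Shift the spectrum so the warm start passed as cov_init is positive
  # definite, mirroring what _fit does before calling the solver.
  w, V = np.linalg.eigh(emp_cov)
  cov_init = (V * (w - min(0, np.min(w)) + 1e-10)).dot(V.T)
  # graphical_lasso returns (covariance, precision); the precision matrix
  # plays the role of the learned Mahalanobis metric M.
  _, M = graphical_lasso(emp_cov, alpha=sparsity_param,
                         verbose=verbose, cov_init=cov_init)
  return M


# Toy usage on a well-conditioned input:
rng = np.random.RandomState(0)
A = rng.randn(50, 4)
M = sparse_metric(np.cov(A, rowvar=False) + 0.1 * np.eye(4))
print(M.shape)  # (4, 4)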
test/metric_learn_test.py

Lines changed: 5 additions & 24 deletions

@@ -15,7 +15,6 @@
                           LSML_Supervised, ITML_Supervised, SDML_Supervised,
                           RCA_Supervised, MMC_Supervised, SDML)
 # Import this specially for testing.
-from metric_learn._util import has_installed_skggm
 from metric_learn.constraints import wrap_pairs
 from metric_learn.lmnn import python_LMNN
 
@@ -150,43 +149,25 @@ def test_no_twice_same_objective(capsys):
 
 class TestSDML(MetricTestCase):
 
-  def test_raises_error_msg_not_installed_skggm(self):
-    """Tests that the right error message is raised if someone tries to
-    instantiate SDML but has not installed skggm"""
-    # TODO: to be removed when scikit-learn v0.21 is released
-    if not has_installed_skggm():
-      msg = ("SDML cannot be instantiated without "
-             "installing skggm. Please install skggm and "
-             "try again (make sure you meet skggm's "
-             "requirements).")
-      with pytest.raises(NotImplementedError) as expected_err:
-        SDML()
-      assert str(expected_err.value) == msg
-    else:  # otherwise we should be able to instantiate SDML and it should
-      # raise no warning
-      with pytest.warns(None) as record:
-        SDML()
-      assert len(record) == 0
-
-  if has_installed_skggm():
-
   def test_iris(self):
     # Note: this is a flaky test, which fails for certain seeds.
     # TODO: un-flake it!
     rs = np.random.RandomState(5555)
 
-    sdml = SDML_Supervised(num_constraints=1500)
+    sdml = SDML_Supervised(num_constraints=1500, use_cov=False,
+                           balance_param=5e-5)
     sdml.fit(self.iris_points, self.iris_labels, random_state=rs)
     csep = class_separation(sdml.transform(self.iris_points),
                             self.iris_labels)
-    self.assertLess(csep, 0.20)
+    self.assertLess(csep, 0.22)
 
   def test_deprecation_num_labeled(self):
    # test that a deprecation message is thrown if num_labeled is set at
    # initialization
    # TODO: remove in v.0.6
    X, y = make_classification()
-    sdml_supervised = SDML_Supervised(num_labeled=np.inf)
+    sdml_supervised = SDML_Supervised(num_labeled=np.inf, use_cov=False,
+                                      balance_param=5e-5)
    msg = ('"num_labeled" parameter is not used.'
           ' It has been deprecated in version 0.5.0 and will be'
           'removed in 0.6.0')

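The updated tests pass use_cov=False (identity prior) and a small balance_param so that the matrix handed to graphical_lasso stays well conditioned. A usage sketch in the spirit of test_iris, assuming the 0.5.x API shown above (where random_state is passed to fit):

import numpy as np
from sklearn.datasets import load_iris
from metric_learn import SDML_Supervised

X, y = load_iris(return_X_y=True)

# Identity prior and a small balance parameter, as in the updated test.
sdml = SDML_Supervised(num_constraints=1500, use_cov=False, balance_param=5e-5)
sdml.fit(X, y, random_state=np.random.RandomState(5555))
X_embedded = sdml.transform(X)
print(X_embedded.shape)  # (150, 4)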
test/test_base_metric.py

Lines changed: 6 additions & 23 deletions

@@ -5,7 +5,6 @@
 from sklearn import clone
 from sklearn.utils.testing import set_random_state
 
-from metric_learn._util import has_installed_skggm
 from test.test_utils import ids_metric_learners, metric_learners
 
 
@@ -54,33 +53,17 @@ def test_lsml(self):
 weights=None)
 """.strip('\n'))
 
-  if has_installed_skggm():
-    def test_sdml(self):
-      self.assertEqual(str(metric_learn.SDML()),
-                       "SDML(balance_param=0.5, preprocessor=None, "
-                       "sparsity_param=0.01, use_cov=True,\n     "
-                       "verbose=False)")
-      self.assertEqual(str(metric_learn.SDML_Supervised()), """
+  def test_sdml(self):
+    self.assertEqual(str(metric_learn.SDML()),
+                     "SDML(balance_param=0.5, preprocessor=None, "
+                     "sparsity_param=0.01, use_cov=True,\n     "
+                     "verbose=False)")
+    self.assertEqual(str(metric_learn.SDML_Supervised()), """
 SDML_Supervised(balance_param=0.5, num_constraints=None,
        num_labeled='deprecated', preprocessor=None, sparsity_param=0.01,
        use_cov=True, verbose=False)
 """.strip('\n'))
 
-  else:
-    # if we haven't install skggm, we will just check that instantiating SDML
-    # or SDML_Supervised will return an error
-    def test_sdml(self):
-      expected_msg = ("SDML cannot be instantiated without "
-                      "installing skggm. Please install skggm and "
-                      "try again (make sure you meet skggm's "
-                      "requirements).")
-      with pytest.raises(NotImplementedError) as raised_error:
-        metric_learn.SDML()
-      assert str(raised_error.value) == expected_msg
-      with pytest.raises(NotImplementedError) as raised_error:
-        metric_learn.SDML_Supervised()
-      assert str(raised_error.value) == expected_msg
-
   def test_rca(self):
     self.assertEqual(str(metric_learn.RCA()),
                      "RCA(num_dims=None, pca_comps=None, preprocessor=None)")

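The expected strings above are simply the default reprs of the estimators; for example (output as asserted in the test, formatted by scikit-learn's 0.20-era __repr__):

import metric_learn

print(metric_learn.SDML())
# SDML(balance_param=0.5, preprocessor=None, sparsity_param=0.01, use_cov=True,
#      verbose=False)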
test/test_fit_transform.py

Lines changed: 13 additions & 13 deletions

@@ -6,9 +6,8 @@
 
 from metric_learn import (
   LMNN, NCA, LFDA, Covariance, MLKR,
-  LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised)
-
-from metric_learn._util import has_installed_skggm
+  LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised,
+  MMC_Supervised)
 
 
 class TestFitTransform(unittest.TestCase):
@@ -63,18 +62,19 @@ def test_lmnn(self):
 
     assert_array_almost_equal(res_1, res_2)
 
-  if has_installed_skggm():
-    def test_sdml_supervised(self):
-      seed = np.random.RandomState(1234)
-      sdml = SDML_Supervised(num_constraints=1500)
-      sdml.fit(self.X, self.y, random_state=seed)
-      res_1 = sdml.transform(self.X)
+  def test_sdml_supervised(self):
+    seed = np.random.RandomState(1234)
+    sdml = SDML_Supervised(num_constraints=1500, balance_param=1e-5,
+                           use_cov=False)
+    sdml.fit(self.X, self.y, random_state=seed)
+    res_1 = sdml.transform(self.X)
 
-      seed = np.random.RandomState(1234)
-      sdml = SDML_Supervised(num_constraints=1500)
-      res_2 = sdml.fit_transform(self.X, self.y, random_state=seed)
+    seed = np.random.RandomState(1234)
+    sdml = SDML_Supervised(num_constraints=1500, balance_param=1e-5,
+                           use_cov=False)
+    res_2 = sdml.fit_transform(self.X, self.y, random_state=seed)
 
-      assert_array_almost_equal(res_1, res_2)
+    assert_array_almost_equal(res_1, res_2)
 
   def test_nca(self):
     n = self.X.shape[0]

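The pattern being preserved here: seeding two estimators with identical RandomState objects makes them draw the same constraints, so fit followed by transform must match fit_transform. A self-contained sketch on iris data (the test itself uses the class fixtures self.X, self.y), assuming the 0.5.x API shown in the diff:

import numpy as np
from numpy.testing import assert_array_almost_equal
from sklearn.datasets import load_iris
from metric_learn import SDML_Supervised

X, y = load_iris(return_X_y=True)

sdml_a = SDML_Supervised(num_constraints=1500, balance_param=1e-5, use_cov=False)
sdml_a.fit(X, y, random_state=np.random.RandomState(1234))
res_a = sdml_a.transform(X)

sdml_b = SDML_Supervised(num_constraints=1500, balance_param=1e-5, use_cov=False)
res_b = sdml_b.fit_transform(X, y, random_state=np.random.RandomState(1234))

assert_array_almost_equal(res_a, res_b)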
test/test_mahalanobis_mixin.py

Lines changed: 0 additions & 2 deletions

@@ -274,8 +274,6 @@ def test_transformer_is_2D(estimator, build_dataset):
   """Tests that the transformer of metric learners is 2D"""
   input_data, labels, _, X = build_dataset()
   model = clone(estimator)
-  if model.__class__.__name__.startswith('SDML'):
-    model.set_params(use_cov=False, balance_param=1e-3)
   set_random_state(model)
   # test that it works for X.shape[1] features
   model.fit(input_data, labels)

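The SDML special-casing can go because the shared fixtures in test_utils.py (further below) now construct SDML with use_cov=False and balance_param=1e-5, and sklearn's clone copies constructor parameters, so no per-test override is needed. A quick sketch of that behaviour:

from sklearn.base import clone
from metric_learn import SDML

# clone() rebuilds the estimator from its constructor parameters.
model = clone(SDML(use_cov=False, balance_param=1e-5))
assert model.get_params()['use_cov'] is False
assert model.get_params()['balance_param'] == 1e-5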
test/test_sklearn_compat.py

Lines changed: 1 addition & 5 deletions

@@ -17,7 +17,6 @@
                                      train_test_split, KFold)
 from sklearn.utils.testing import _get_args
 
-from metric_learn._util import has_installed_skggm
 from test.test_utils import (metric_learners, ids_metric_learners,
                              mock_preprocessor)
 
@@ -75,10 +74,7 @@ def test_mmc(self):
     check_estimator(dMMC)
 
   def test_sdml(self):
-    if has_installed_skggm():
-      check_estimator(dSDML)
-    else:
-      pass
+    check_estimator(dSDML)
 
   # This fails because the default num_chunks isn't data-dependent.
   # def test_rca(self):

test/test_transformer_metric_conversion.py

Lines changed: 6 additions & 8 deletions

@@ -6,7 +6,6 @@
 from metric_learn import (
   LMNN, NCA, LFDA, Covariance, MLKR,
   LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised)
-from metric_learn._util import has_installed_skggm
 
 
 class TestTransformerMetricConversion(unittest.TestCase):
@@ -43,13 +42,12 @@ def test_lmnn(self):
     L = lmnn.transformer_
     assert_array_almost_equal(L.T.dot(L), lmnn.get_mahalanobis_matrix())
 
-  if has_installed_skggm():
-    def test_sdml_supervised(self):
-      seed = np.random.RandomState(1234)
-      sdml = SDML_Supervised(num_constraints=1500)
-      sdml.fit(self.X, self.y, random_state=seed)
-      L = sdml.transformer_
-      assert_array_almost_equal(L.T.dot(L), sdml.get_mahalanobis_matrix())
+  def test_sdml_supervised(self):
+    seed = np.random.RandomState(1234)
+    sdml = SDML_Supervised(num_constraints=1500)
+    sdml.fit(self.X, self.y, random_state=seed)
+    L = sdml.transformer_
+    assert_array_almost_equal(L.T.dot(L), sdml.get_mahalanobis_matrix())
 
   def test_nca(self):
     n = self.X.shape[0]

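As a reminder of the invariant this test checks (a numpy sketch, not library code): the learned transformer L and the Mahalanobis matrix M satisfy M = L.T.dot(L), so squared distances under M equal squared Euclidean distances after applying L.

import numpy as np

rng = np.random.RandomState(0)
L = rng.randn(4, 4)              # stand-in for a learned transformer_
M = L.T.dot(L)                   # corresponding Mahalanobis matrix

x, y = rng.randn(4), rng.randn(4)
d2_metric = (x - y).dot(M).dot(x - y)              # squared Mahalanobis distance
d2_embedded = np.sum((L.dot(x) - L.dot(y)) ** 2)   # squared Euclidean after L
assert np.isclose(d2_metric, d2_embedded)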
test/test_utils.py

Lines changed: 4 additions & 8 deletions

@@ -9,8 +9,7 @@
 from sklearn.base import clone
 from metric_learn._util import (check_input, make_context, preprocess_tuples,
                                 make_name, preprocess_points,
-                                check_collapsed_pairs, validate_vector,
-                                has_installed_skggm)
+                                check_collapsed_pairs, validate_vector)
 from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA,
                           LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised,
                           MMC_Supervised, RCA_Supervised, SDML_Supervised,
@@ -105,9 +104,7 @@ def build_quadruplets(with_preprocessor=False):
 
 pairs_learners = [(ITML(), build_pairs),
                   (MMC(max_iter=2), build_pairs),  # max_iter=2 for faster
-                  ]
-if has_installed_skggm():
-  pairs_learners.append(((SDML(), build_pairs)))
+                  (SDML(use_cov=False, balance_param=1e-5), build_pairs)]
 ids_pairs_learners = list(map(lambda x: x.__class__.__name__,
                               [learner for (learner, _) in
                                pairs_learners]))
@@ -121,9 +118,8 @@ def build_quadruplets(with_preprocessor=False):
                (LSML_Supervised(), build_classification),
                (MMC_Supervised(max_iter=5), build_classification),
                (RCA_Supervised(num_chunks=10), build_classification),
-               ]
-if has_installed_skggm():
-  classifiers.append(((SDML_Supervised(), build_classification)))
+               (SDML_Supervised(use_cov=False, balance_param=1e-5),
+                build_classification)]
 ids_classifiers = list(map(lambda x: x.__class__.__name__,
                            [learner for (learner, _) in
                             classifiers]))

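For context, these (learner, build_fn) lists feed pytest.mark.parametrize elsewhere in the suite; with skggm gone, SDML is simply listed unconditionally. A reduced, self-contained illustration of the pattern (the test name below is hypothetical):

import pytest
from metric_learn import ITML, MMC, SDML

learners = [ITML(), MMC(max_iter=2),
            SDML(use_cov=False, balance_param=1e-5)]

@pytest.mark.parametrize('learner', learners,
                         ids=[l.__class__.__name__ for l in learners])
def test_has_fit_method(learner):
  assert hasattr(learner, 'fit')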