
Commit 60866cb

Author: William de Vazelhes (committed)

Nitpick for concatenation and refactor HAS_SKGGM

1 parent 187e22c · commit 60866cb

File tree

  metric_learn/_util.py
  metric_learn/constraints.py
  metric_learn/sdml.py
  test/metric_learn_test.py

4 files changed: +20 −20 lines changed

metric_learn/_util.py
Lines changed: 0 additions & 8 deletions

@@ -15,14 +15,6 @@ def vector_norm(X):
   return np.linalg.norm(X, axis=1)
 
 
-def has_installed_skggm():
-  try:
-    import inverse_covariance
-    return True
-  except ImportError:
-    return False
-
-
 def check_input(input_data, y=None, preprocessor=None,
                 type_of_inputs='classic', tuple_size=None, accept_sparse=False,
                 dtype='numeric', order=None,

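For reference, the helper deleted above re-ran the import check on every call, while the rest of this commit replaces it with a module-level flag evaluated once when the module is imported. Both fragments appear in this commit's diffs; they are shown together here only for comparison:

# Old style (removed from _util.py): a function that tries the import each time it is called.
def has_installed_skggm():
  try:
    import inverse_covariance
    return True
  except ImportError:
    return False

# New style (added to sdml.py and the tests): try the import once at module load, keep a flag.
try:
  from inverse_covariance import quic
except ImportError:
  HAS_SKGGM = False
else:
  HAS_SKGGM = True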
metric_learn/constraints.py
Lines changed: 1 addition & 1 deletion

@@ -96,6 +96,6 @@ def wrap_pairs(X, constraints):
   c = np.array(constraints[2])
   d = np.array(constraints[3])
   constraints = np.vstack((np.column_stack((a, b)), np.column_stack((c, d))))
-  y = np.hstack([np.ones((len(a),)), - np.ones((len(c),))])
+  y = np.concatenate([np.ones_like(a), -np.ones_like(c)])
   pairs = X[constraints]
   return pairs, y

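A side note on the nitpick above: np.ones_like(a) produces a vector of ones with the same length (and integer dtype) as the index array a, so the explicit shape tuple is no longer needed, and np.concatenate plays the same role as np.hstack for these 1-D arrays. A quick standalone check, with made-up index arrays:

import numpy as np

a = np.array([0, 1, 2])  # hypothetical indices of similar pairs
c = np.array([3, 4])     # hypothetical indices of dissimilar pairs

y_old = np.hstack([np.ones((len(a),)), -np.ones((len(c),))])   # float labels
y_new = np.concatenate([np.ones_like(a), -np.ones_like(c)])    # integer labels

assert (y_old == y_new).all()   # same +1 / -1 labels either way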
metric_learn/sdml.py
Lines changed: 8 additions & 4 deletions

@@ -18,9 +18,13 @@
 
 from .base_metric import MahalanobisMixin, _PairsClassifierMixin
 from .constraints import Constraints, wrap_pairs
-from ._util import transformer_from_metric, has_installed_skggm
-if has_installed_skggm():
+from ._util import transformer_from_metric
+try:
   from inverse_covariance import quic
+except ImportError:
+  HAS_SKGGM = False
+else:
+  HAS_SKGGM = True
 
 
 class _BaseSDML(MahalanobisMixin):

@@ -55,7 +59,7 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
     super(_BaseSDML, self).__init__(preprocessor)
 
   def _fit(self, pairs, y):
-    if not has_installed_skggm():
+    if not HAS_SKGGM:
       msg = ("Warning, skggm is not installed, so SDML will use "
              "scikit-learn's graphical_lasso method. It can fail to converge"
              "on some non SPD matrices where skggm would converge. If so, "

@@ -89,7 +93,7 @@ def _fit(self, pairs, y):
              "`balance_param` and/or to set use_covariance=False.",
              ConvergenceWarning)
       sigma0 = (V * (w - min(0, np.min(w)) + 1e-10)).dot(V.T)
-    if has_installed_skggm():
+    if HAS_SKGGM:
       theta0 = pinvh(sigma0)
       M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
                               msg=self.verbose,

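With the flag computed at import time, _fit only needs a boolean test to decide between skggm's quic solver and the scikit-learn graphical lasso fallback mentioned in the warning above. A simplified, hypothetical sketch of that control flow (fit_sparse_precision is an illustrative helper, not a function from the repository, and the real _fit does considerably more around it):

try:
  from inverse_covariance import quic
except ImportError:
  HAS_SKGGM = False
else:
  HAS_SKGGM = True


def fit_sparse_precision(emp_cov, sparsity_param, verbose=False):
  # Use skggm's QUIC solver when it is installed; quic returns six values,
  # the first of which is the estimated sparse precision matrix.
  if HAS_SKGGM:
    M, _, _, _, _, _ = quic(emp_cov, lam=sparsity_param, msg=verbose)
    return M
  # Otherwise fall back to scikit-learn's graphical lasso (simplified; the
  # convergence caveats from the warning above apply on non-SPD inputs).
  from sklearn.covariance import graphical_lasso
  _, precision = graphical_lasso(emp_cov, alpha=sparsity_param, verbose=verbose)
  return precision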
test/metric_learn_test.py
Lines changed: 11 additions & 7 deletions

@@ -10,12 +10,16 @@
 from sklearn.utils.testing import assert_warns_message
 from sklearn.exceptions import ConvergenceWarning
 from sklearn.utils.validation import check_X_y
-
+try:
+  from inverse_covariance import quic
+except ImportError:
+  HAS_SKGGM = False
+else:
+  HAS_SKGGM = True
 from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC,
                           LSML_Supervised, ITML_Supervised, SDML_Supervised,
                           RCA_Supervised, MMC_Supervised, SDML)
 # Import this specially for testing.
-from metric_learn._util import has_installed_skggm
 from metric_learn.constraints import wrap_pairs
 from metric_learn.lmnn import python_LMNN
 

@@ -150,7 +154,7 @@ def test_no_twice_same_objective(capsys):
 
 class TestSDML(MetricTestCase):
 
-  @pytest.mark.skipif(has_installed_skggm(),
+  @pytest.mark.skipif(HAS_SKGGM,
                       reason="The warning will be thrown only if skggm is "
                              "not installed.")
   def test_raises_warning_msg_not_installed_skggm(self):

@@ -174,7 +178,7 @@ def test_raises_warning_msg_not_installed_skggm(self):
     sdml_supervised.fit(X, y)
     assert str(record[0].message) == msg
 
-  @pytest.mark.skipif(not has_installed_skggm(),
+  @pytest.mark.skipif(not HAS_SKGGM,
                       reason="It's only in the case where skggm is installed"
                              "that no warning should be thrown.")
   def test_raises_no_warning_installed_skggm(self):

@@ -245,7 +249,7 @@ def test_sdml_converges_if_psd(self):
     sdml.fit(pairs, y)
     assert np.isfinite(sdml.get_mahalanobis_matrix()).all()
 
-  @pytest.mark.skipif(not has_installed_skggm(),
+  @pytest.mark.skipif(not HAS_SKGGM,
                       reason="sklearn's graphical_lasso can sometimes not "
                              "work on some non SPD problems. We test that "
                              "is works only if skggm is installed.")

@@ -258,7 +262,7 @@ def test_sdml_works_on_non_spd_pb_with_skggm(self):
     sdml.fit(X, y)
 
 
-@pytest.mark.skipif(not has_installed_skggm(),
+@pytest.mark.skipif(not HAS_SKGGM,
                     reason='The message should be printed only if skggm is '
                            'installed.')
 def test_verbose_has_installed_skggm_sdml(capsys):

@@ -273,7 +277,7 @@ def test_verbose_has_installed_skggm_sdml(capsys):
   assert "SDML will use skggm's solver." in out
 
 
-@pytest.mark.skipif(not has_installed_skggm(),
+@pytest.mark.skipif(not HAS_SKGGM,
                     reason='The message should be printed only if skggm is '
                            'installed.')
 def test_verbose_has_installed_skggm_sdml_supervised(capsys):

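The test module computes the same HAS_SKGGM flag at import time, so pytest.mark.skipif can gate each test on whether skggm is importable, in both directions. A standalone sketch with made-up test bodies:

import pytest

try:
  from inverse_covariance import quic
except ImportError:
  HAS_SKGGM = False
else:
  HAS_SKGGM = True


@pytest.mark.skipif(not HAS_SKGGM,
                    reason="This code path needs skggm.")
def test_runs_only_with_skggm():
  # Hypothetical body: executed only when skggm is installed.
  assert callable(quic)


@pytest.mark.skipif(HAS_SKGGM,
                    reason="The fallback warning appears only without skggm.")
def test_runs_only_without_skggm():
  # Hypothetical body: executed only when skggm is missing.
  assert not HAS_SKGGM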