Skip to content

Commit e8c74d0

Browse files
wdevazelhesbellet
authored andcommitted
[MRG] FIX: sdml formulation and solvers (#162)
* FIX: make proposal for sdml formulation * MAINT clearer formulation to make the prior appear * MAINT call the prior prior * Use skggm instead of graphical lasso * Be more severe for the class separation * Put back verbose param * MAINT: make more explicit the fact that to use identity (i.e. an SPD matrix) as initialization * Add skggm as a requirement for SDML * Add skggm to required packages for travis * Also add cython as a dependency * FIX: install all except skggm and then skggm * Remove cython dependency * Install skggm only if we have at least python 3.6 * Should work if we want other versions superior to 3.6 * Fix bash >= which should be written -ge * Deal with tests when skggm is not installed and fix some PEP8 warnings * replace manual calls of algorithms with tuples_learners * Remove another call of SDML if skggm is not installed * FIX fix the test_error_message_tuple_size * FIX fix test_sdml_supervised * FIX: fix another sdml test * FIX quic call for python 2.7 * Fix quic import * Add Sigma0 initalization (both sigma zero and theta zero should be specified otherwise an error is returned * Deal with SDML making some tests fail * Remove epsilon that was unnecessary * FIX: use latest commit of skggm that fixes the non deterministic problem * MAINT: add message for SDML when not SPD * MAINT: add test for error message if skggm not installed * Try other syntax for installing the right commit of skggm * MAINT: make sklearn compat sdml test be run only if skggm is installed * Try another syntax for running travis * Better bash syntax * Fix tests by removing duplicates * FIX: fix for sdml by reducing balance parameter * FIX: update code to work with old version of numpy that does not have axis for unique * Remove the need for skggm * Update travis not to use skggm * Add a stable init for sklearn checks * FIX test_sdml_supervised * Revert "Update travis not to use skggm" This reverts commit 57b0567. * Add fallback on skggm * FIX: fix versions comparison and tests * MAINT: improve test of no warning * FIX: fix wrap pairs that was returning column y (we need line y), and fix the example for SDML to not raise another warning * FIX: force travis to do the right check * TST: add non SPD test that works with skggm's quic but not sklearn's graphical_lasso * Try again travis this time installing cython * Try to make travis work with build_essential * Try with installing liblapack * TST: fix tests for when skggm is not installed * TST: use better pytest skipif syntax * FIX: fix broken link in README.md * use rst syntax for link * use rst syntax for link * use rst syntax for link * MAINT: remove test_sdml that was remaining from drafts tests * TST: remove skipping SDML in test_cross_validation_manual_vs_scikit * FIX link also in getting started * Put back right indent * Remove unnecessary changes * Nitpick for concatenation and refactor HAS_SKGGM * ENH: Deal better with errors and skggm/scikit-learn * Better creation of prior * Simplification for init of sdml * Put skggm as optional * Specify skggm version * TST: make test about 1 feature arrays more readable * DOC: fix rst formatting * DOC: reformulated skggm optional dependency * TST: give an example for sdml_supervised with skggm where it indeed fails * TST: fix test that fails weirdly when executing the whole test file and not just the test * Revert "TST: fix test that fails weirdly when executing the whole test file and not just the test" This reverts commit 6f5666b. * Add coverage for all versions of python * Install pytest-cov for all versions
1 parent 4e37d7c commit e8c74d0

12 files changed

+375
-63
lines changed

.travis.yml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@ cache: pip
44
python:
55
- "2.7"
66
- "3.4"
7+
- "3.6"
78
before_install:
9+
- sudo apt-get install liblapack-dev
810
- pip install --upgrade pip pytest
9-
- pip install wheel
10-
- pip install codecov
11-
- if [[ $TRAVIS_PYTHON_VERSION == "3.4" ]];
12-
then pip install pytest-cov;
11+
- pip install wheel cython numpy scipy scikit-learn codecov pytest-cov
12+
- if [[ ($TRAVIS_PYTHON_VERSION == "3.6") ||
13+
($TRAVIS_PYTHON_VERSION == "2.7")]]; then
14+
pip install git+https://github.com/skggm/skggm.git@a0ed406586c4364ea3297a658f415e13b5cbdaf8;
1315
fi
14-
- pip install numpy scipy scikit-learn
1516
script:
16-
- if [[ $TRAVIS_PYTHON_VERSION == "3.4" ]];
17-
then pytest test --cov;
18-
else pytest test;
19-
fi
17+
# we do coverage for all versions so that codecov will merge them: this
18+
# way we will see that both paths (with or without skggm) are tested
19+
- pytest test --cov;
2020
after_success:
2121
- bash <(curl -s https://codecov.io/bash)
22+

README.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@ Metric Learning algorithms in Python.
2121

2222
- Python 2.7+, 3.4+
2323
- numpy, scipy, scikit-learn
24-
- (for running the examples only: matplotlib)
24+
25+
**Optional dependencies**
26+
27+
- For SDML, using skggm will allow the algorithm to solve problematic cases
28+
(install from commit `a0ed406 <https://github.com/skggm/skggm/commit/a0ed406586c4364ea3297a658f415e13b5cbdaf8>`_).
29+
- For running the examples only: matplotlib
2530

2631
**Installation/Setup**
2732

doc/getting_started.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@ Alternately, download the source repository and run:
1616

1717
- Python 2.7+, 3.4+
1818
- numpy, scipy, scikit-learn
19-
- (for running the examples only: matplotlib)
19+
20+
**Optional dependencies**
21+
22+
- For SDML, using skggm will allow the algorithm to solve problematic cases
23+
(install from commit `a0ed406 <https://github.com/skggm/skggm/commit/a0ed406586c4364ea3297a658f415e13b5cbdaf8>`_).
24+
- For running the examples only: matplotlib
2025

2126
**Notes**
2227

metric_learn/constraints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,6 @@ def wrap_pairs(X, constraints):
9696
c = np.array(constraints[2])
9797
d = np.array(constraints[3])
9898
constraints = np.vstack((np.column_stack((a, b)), np.column_stack((c, d))))
99-
y = np.vstack([np.ones((len(a), 1)), - np.ones((len(c), 1))])
99+
y = np.concatenate([np.ones_like(a), -np.ones_like(c)])
100100
pairs = X[constraints]
101101
return pairs, y

metric_learn/sdml.py

Lines changed: 69 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,19 @@
1212
import warnings
1313
import numpy as np
1414
from sklearn.base import TransformerMixin
15-
from sklearn.covariance import graph_lasso
16-
from sklearn.utils.extmath import pinvh
15+
from scipy.linalg import pinvh
16+
from sklearn.covariance import graphical_lasso
17+
from sklearn.exceptions import ConvergenceWarning
1718

1819
from .base_metric import MahalanobisMixin, _PairsClassifierMixin
1920
from .constraints import Constraints, wrap_pairs
2021
from ._util import transformer_from_metric
22+
try:
23+
from inverse_covariance import quic
24+
except ImportError:
25+
HAS_SKGGM = False
26+
else:
27+
HAS_SKGGM = True
2128

2229

2330
class _BaseSDML(MahalanobisMixin):
@@ -52,24 +59,74 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
5259
super(_BaseSDML, self).__init__(preprocessor)
5360

5461
def _fit(self, pairs, y):
62+
if not HAS_SKGGM:
63+
if self.verbose:
64+
print("SDML will use scikit-learn's graphical lasso solver.")
65+
else:
66+
if self.verbose:
67+
print("SDML will use skggm's graphical lasso solver.")
5568
pairs, y = self._prepare_inputs(pairs, y,
5669
type_of_inputs='tuples')
5770

58-
# set up prior M
71+
# set up (the inverse of) the prior M
5972
if self.use_cov:
6073
X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])})
61-
M = pinvh(np.atleast_2d(np.cov(X, rowvar = False)))
74+
prior_inv = np.atleast_2d(np.cov(X, rowvar=False))
6275
else:
63-
M = np.identity(pairs.shape[2])
76+
prior_inv = np.identity(pairs.shape[2])
6477
diff = pairs[:, 0] - pairs[:, 1]
6578
loss_matrix = (diff.T * y).dot(diff)
66-
P = M + self.balance_param * loss_matrix
67-
emp_cov = pinvh(P)
68-
# hack: ensure positive semidefinite
69-
emp_cov = emp_cov.T.dot(emp_cov)
70-
_, M = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)
71-
72-
self.transformer_ = transformer_from_metric(M)
79+
emp_cov = prior_inv + self.balance_param * loss_matrix
80+
81+
# our initialization will be the matrix with emp_cov's eigenvalues,
82+
# with a constant added so that they are all positive (plus an epsilon
83+
# to ensure definiteness). This is empirical.
84+
w, V = np.linalg.eigh(emp_cov)
85+
min_eigval = np.min(w)
86+
if min_eigval < 0.:
87+
warnings.warn("Warning, the input matrix of graphical lasso is not "
88+
"positive semi-definite (PSD). The algorithm may diverge, "
89+
"and lead to degenerate solutions. "
90+
"To prevent that, try to decrease the balance parameter "
91+
"`balance_param` and/or to set use_covariance=False.",
92+
ConvergenceWarning)
93+
w -= min_eigval # we translate the eigenvalues to make them all positive
94+
w += 1e-10 # we add a small offset to avoid definiteness problems
95+
sigma0 = (V * w).dot(V.T)
96+
try:
97+
if HAS_SKGGM:
98+
theta0 = pinvh(sigma0)
99+
M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
100+
msg=self.verbose,
101+
Theta0=theta0, Sigma0=sigma0)
102+
else:
103+
_, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
104+
verbose=self.verbose,
105+
cov_init=sigma0)
106+
raised_error = None
107+
w_mahalanobis, _ = np.linalg.eigh(M)
108+
not_spd = any(w_mahalanobis < 0.)
109+
not_finite = not np.isfinite(M).all()
110+
except Exception as e:
111+
raised_error = e
112+
not_spd = False # not_spd not applicable here so we set to False
113+
not_finite = False # not_finite not applicable here so we set to False
114+
if raised_error is not None or not_spd or not_finite:
115+
msg = ("There was a problem in SDML when using {}'s graphical "
116+
"lasso solver.").format("skggm" if HAS_SKGGM else "scikit-learn")
117+
if not HAS_SKGGM:
118+
skggm_advice = (" skggm's graphical lasso can sometimes converge "
119+
"on non SPD cases where scikit-learn's graphical "
120+
"lasso fails to converge. Try to install skggm and "
121+
"rerun the algorithm (see the README.md for the "
122+
"right version of skggm).")
123+
msg += skggm_advice
124+
if raised_error is not None:
125+
msg += " The following error message was thrown: {}.".format(
126+
raised_error)
127+
raise RuntimeError(msg)
128+
129+
self.transformer_ = transformer_from_metric(np.atleast_2d(M))
73130
return self
74131

75132

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
extras_require=dict(
3939
docs=['sphinx', 'shinx_rtd_theme', 'numpydoc'],
4040
demo=['matplotlib'],
41+
sdml=['skggm>=0.2.9']
4142
),
4243
test_suite='test',
4344
keywords=[

0 commit comments

Comments
 (0)