# [MRG] FIX: make proposal for sdml formulation #162
Changes from 60 commits
**`README.rst`**

```diff
@@ -15,7 +15,7 @@ Alternately, download the source repository and run:
 **Dependencies**

 - Python 2.7+, 3.4+
-- numpy, scipy, scikit-learn
+- numpy, scipy, scikit-learn, and skggm (commit `a0ed406 <https://github.com/skggm/skggm/commit/a0ed406586c4364ea3297a658f415e13b5cbdaf8>`_) for `SDML`

 - (for running the examples only: matplotlib)

 **Notes**
```

**Comment** (on the skggm line): same comment here as https://github.com/metric-learn/metric-learn/pull/162/files#r264683385
**`metric_learn/_util.py`**

```diff
@@ -15,6 +15,14 @@ def vector_norm(X):
   return np.linalg.norm(X, axis=1)


+def has_installed_skggm():
+  try:
+    import inverse_covariance
+    return True
+  except ImportError:
+    return False
+
+
 def check_input(input_data, y=None, preprocessor=None,
                 type_of_inputs='classic', tuple_size=None, accept_sparse=False,
                 dtype='numeric', order=None,
```

**Comment** (on `has_installed_skggm`): I don't think this needs to be a function, as the answer isn't going to change during execution.

**Reply:** That's right, thanks.
**`metric_learn/constraints.py`**

```diff
@@ -96,6 +96,6 @@ def wrap_pairs(X, constraints):
   c = np.array(constraints[2])
   d = np.array(constraints[3])
   constraints = np.vstack((np.column_stack((a, b)), np.column_stack((c, d))))
-  y = np.vstack([np.ones((len(a), 1)), - np.ones((len(c), 1))])
+  y = np.hstack([np.ones((len(a),)), - np.ones((len(c),))])
   pairs = X[constraints]
   return pairs, y
```

**Comment:** In fact we should never return a column vector but a row vector (these are the ones scikit-learn likes to work on).

**Reply:** Maybe simpler to do: …

**Reply:** Indeed, thanks.
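To see the shape difference driving this change, here is a quick illustration (not part of the PR):

```python
import numpy as np

# Old behavior: stacking column vectors gives a 2-D column array.
y_col = np.vstack([np.ones((3, 1)), -np.ones((2, 1))])
print(y_col.shape)   # (5, 1) -- a column vector

# New behavior: hstack on 1-D arrays gives the flat shape scikit-learn
# expects for label vectors.
y_flat = np.hstack([np.ones((3,)), -np.ones((2,))])
print(y_flat.shape)  # (5,)
```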
**`metric_learn/sdml.py`**

```diff
@@ -12,12 +12,15 @@
 import warnings
 import numpy as np
 from sklearn.base import TransformerMixin
-from sklearn.covariance import graph_lasso
-from sklearn.utils.extmath import pinvh
+from scipy.linalg import pinvh
+from sklearn.covariance import graphical_lasso
+from sklearn.exceptions import ConvergenceWarning

 from .base_metric import MahalanobisMixin, _PairsClassifierMixin
 from .constraints import Constraints, wrap_pairs
-from ._util import transformer_from_metric
+from ._util import transformer_from_metric, has_installed_skggm
+if has_installed_skggm():
+  from inverse_covariance import quic


 class _BaseSDML(MahalanobisMixin):
```

**Comment** (on `if has_installed_skggm():`): I'd prefer a simpler conditional import here:

```python
try:
    from inverse_covariance import quic
except ImportError:
    HAS_SKGGM = False
else:
    HAS_SKGGM = True
```

**Reply:** We'd need to duplicate the logic in the tests, but I'm fine with that.

**Reply:** Agreed.
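Since the thread notes the detection logic would need duplicating in the tests, here is a minimal sketch of how a test module might reuse such a flag (hypothetical; it assumes pytest, and `needs_skggm` is an illustrative name, not from the PR):

```python
import pytest

try:
    from inverse_covariance import quic  # noqa: F401 -- presence check only
    HAS_SKGGM = True
except ImportError:
    HAS_SKGGM = False

# Decorator to skip skggm-specific tests when skggm is not installed.
needs_skggm = pytest.mark.skipif(not HAS_SKGGM, reason='skggm is not installed')
```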
```diff
@@ -52,24 +55,50 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
     super(_BaseSDML, self).__init__(preprocessor)

   def _fit(self, pairs, y):
+    if not has_installed_skggm():
+      msg = ("Warning, skggm is not installed, so SDML will use "
+             "scikit-learn's graphical_lasso method. It can fail to converge "
+             "on some non SPD matrices where skggm would converge. If so, "
+             "try to install skggm. (see the README.md for the right "
+             "version.)")
```
**Comment:** Perhaps we can catch the case where scikit-learn's version fails and emit the warning then?

**Reply:** I agree, here is the new version I'll commit: when using scikit-learn's graphical lasso, we try it, and if an error is raised or the result is not finite, we emit a warning that will be printed before the error (if there is an error), or before returning M (if there is no error but there are NaNs). Tell me what you think:

```python
if HAS_SKGGM:
  theta0 = pinvh(sigma0)
  M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
                          msg=self.verbose,
                          Theta0=theta0, Sigma0=sigma0)
else:
  try:
    _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
                           verbose=self.verbose,
                           cov_init=sigma0)
    error = None
  except FloatingPointError as e:
    error = e
  if not np.isfinite(M).all() or error is not None:
    msg = ("Scikit-learn's graphical lasso has failed to converge. "
           "Package skggm's graphical lasso can sometimes converge on "
           "non SPD cases where scikit-learn's graphical lasso fails to "
           "converge. Try to install skggm and rerun the algorithm. (see "
           "the README.md for the right version.)")
    warnings.warn(msg)
    if error is not None:
      raise(error)
```

EDIT:

```python
if HAS_SKGGM:
  theta0 = pinvh(sigma0)
  M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
                          msg=self.verbose,
                          Theta0=theta0, Sigma0=sigma0)
else:
  try:
    _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
                           verbose=self.verbose,
                           cov_init=sigma0)
  except FloatingPointError as e:
    msg = ("Scikit-learn's graphical lasso has failed to converge. "
           "Package skggm's graphical lasso can sometimes converge on "
           "non SPD cases where scikit-learn's graphical lasso fails to "
           "converge. Try to install skggm and rerun the algorithm. (see "
           "the README.md for the right version of skggm.)")
    warnings.warn(msg)
    raise(e)
```

(In fact it's skggm's graphical lasso that throws NaNs; scikit-learn's graphical lasso raises FloatingPointError on failure (I didn't find cases where it would give NaNs), so it's better to stick to the case we know.)

**Reply:** Actually, trying to come up with the right tests, I realized the following. Example (go in debug mode or put a print statement in SDML to see the result of the graphical lasso):

```python
from metric_learn import SDML
import numpy as np

pairs = np.array([[[-10., 0.], [10., 0.]], [[0., 50.], [0., -60]]])
y_pairs = [1, -1]
sdml = SDML(use_cov=False, balance_param=100, verbose=True)
diff = pairs[:, 0] - pairs[:, 1]
emp_cov = np.identity(pairs.shape[2]) + 100 * (diff.T * y_pairs).dot(diff)
print(emp_cov)
sdml.fit(pairs, y_pairs)
print(sdml.get_mahalanobis_matrix())
```

Returns: …

And if we print the result of graphical lasso (note that it's the inverse of the initial matrix): …

**Reply:** How about this: …

**Reply:** I agree, I'll go for something like this:

```python
try:
  if HAS_SKGGM:
    theta0 = pinvh(sigma0)
    M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
                            msg=self.verbose,
                            Theta0=theta0, Sigma0=sigma0)
  else:
    _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
                           verbose=self.verbose,
                           cov_init=sigma0)
  raised_error = None
  w_mahalanobis, _ = np.linalg.eigh(M)
  not_spd = any(w_mahalanobis < 0.)
except Exception as e:
  raised_error = e
  not_spd = False  # not_spd not applicable so we set to False
if raised_error is not None or not_spd:
  msg = ("There was a problem in SDML when using {}'s graphical "
         "lasso.").format("skggm" if HAS_SKGGM else "scikit-learn")
  if not HAS_SKGGM:
    skggm_advice = ("skggm's graphical lasso can sometimes converge "
                    "on non SPD cases where scikit-learn's graphical "
                    "lasso fails to converge. Try to install skggm and "
                    "rerun the algorithm. (See the README.md for the "
                    "right version of skggm.)")
    msg += skggm_advice
  if raised_error is not None:
    msg += "The following error message was thrown: {}.".format(
        raised_error)
  raise RuntimeError(msg)
```

**Reply:** I'll add infinite values that can be returned by skggm as a failure case too.

**Reply:** Done, in commit eb95719.
```diff
+      warnings.warn(msg)
+    else:
+      print("SDML will use skggm's solver.")
```

**Comment:** We should only print this if …

**Reply:** And maybe clarify: "skggm's graphical lasso solver".

**Reply:** And maybe print something similar when the sklearn solver is used.

**Reply:** Agreed.
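A sketch of what the agreed change might look like (hypothetical: it assumes the print is gated on the existing `verbose` flag, and the exact wording landed in a later commit):

```python
if self.verbose:
    if HAS_SKGGM:
        print("SDML will use skggm's graphical lasso solver.")
    else:
        print("SDML will use scikit-learn's graphical lasso solver.")
```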
```diff
     pairs, y = self._prepare_inputs(pairs, y,
                                     type_of_inputs='tuples')

     # set up prior M
     if self.use_cov:
       X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])})
-      M = pinvh(np.atleast_2d(np.cov(X, rowvar = False)))
+      prior = pinvh(np.atleast_2d(np.cov(X, rowvar=False)))
     else:
-      M = np.identity(pairs.shape[2])
+      prior = np.identity(pairs.shape[2])
     diff = pairs[:, 0] - pairs[:, 1]
     loss_matrix = (diff.T * y).dot(diff)
-    P = M + self.balance_param * loss_matrix
-    emp_cov = pinvh(P)
-    # hack: ensure positive semidefinite
-    emp_cov = emp_cov.T.dot(emp_cov)
-    _, M = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)
-
-    self.transformer_ = transformer_from_metric(M)
+    emp_cov = pinvh(prior) + self.balance_param * loss_matrix
```

**Comment** (on the new `emp_cov` line): Seems odd to round-trip the covariance matrix through two …

**Reply:** I agree.

**Reply:** Agreed, done.
```diff
+
+    # our initialization will be the matrix with emp_cov's eigenvalues,
+    # with a constant added so that they are all positive (plus an epsilon
+    # to ensure definiteness). This is empirical.
+    w, V = np.linalg.eigh(emp_cov)
+    if any(w < 0.):
+      warnings.warn("Warning, the input matrix of graphical lasso is not "
+                    "positive semi-definite (PSD). The algorithm may diverge, "
+                    "and lead to degenerate solutions. "
+                    "To prevent that, try to decrease the balance parameter "
+                    "`balance_param` and/or to set use_covariance=False.",
+                    ConvergenceWarning)
+    sigma0 = (V * (w - min(0, np.min(w)) + 1e-10)).dot(V.T)
```

**Comment** (on the initialization): This is an init that we talked about with @bellet; I found it worked better (it allowed tests to pass, whereas with the identity init I had a lot of LinAlgErrors).

**Comment** (on the warning): is …

**Reply:** I'm fine with it.
**Comment** (on the `sigma0` line): Maybe simpler:

```python
min_eigval = w.min()
if min_eigval < 0:
    warnings.warn(...)
else:
    min_eigval = 0
w += 1e-10 - min_eigval
sigma0 = (V * w).dot(V.T)
```

**Reply:** In the heuristic:

```python
min_eigval = w.min()
if min_eigval < 0:
    warnings.warn(...)
    w -= min_eigval  # we translate the eigenvalues to make them all positive
w += 1e-10  # we add a small offset to avoid definiteness problems
sigma0 = (V * w).dot(V.T)
```

**Reply:** Done.
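As a toy illustration of the shifting heuristic under discussion (not from the PR):

```python
import numpy as np

# A symmetric but indefinite matrix: its eigenvalues are 3 and -1.
emp_cov = np.array([[1., 2.],
                    [2., 1.]])
w, V = np.linalg.eigh(emp_cov)
print(w)  # [-1.  3.] -- not PSD

# Translate the spectrum so all eigenvalues become (just barely) positive.
sigma0 = (V * (w - min(0, np.min(w)) + 1e-10)).dot(V.T)
print(np.linalg.eigvalsh(sigma0))  # [~1e-10, 4.] -- strictly positive
```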
```diff
+    if has_installed_skggm():
+      theta0 = pinvh(sigma0)
+      M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
+                              msg=self.verbose,
+                              Theta0=theta0, Sigma0=sigma0)
+    else:
+      _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
+                             verbose=self.verbose,
+                             cov_init=sigma0)
+    self.transformer_ = transformer_from_metric(np.atleast_2d(M))
     return self
```
**`setup.py`**

```diff
@@ -38,6 +38,7 @@
       extras_require=dict(
           docs=['sphinx', 'shinx_rtd_theme', 'numpydoc'],
           demo=['matplotlib'],
+          sdml=['skggm']
       ),
       test_suite='test',
       keywords=[
```

**Comment** (on the `sdml` extra): Can we specify a commit hash here? Or maybe, since their latest release is 0.2.8, we could specify …

**Reply:** I tried indeed but didn't manage to make it work. But good idea, `skggm>=0.2.9` is better than nothing here.
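The version floor mentioned in the reply would look like this (a sketch; `skggm>=0.2.9` anticipates the first release containing the pinned commit, which is an assumption from the thread):

```python
# Hypothetical variant of the extras_require entry discussed above:
extras = dict(
    docs=['sphinx', 'shinx_rtd_theme', 'numpydoc'],
    demo=['matplotlib'],
    sdml=['skggm>=0.2.9'],
)
```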
**Tests**

```diff
@@ -96,7 +96,7 @@ def check_is_distance_matrix(pairwise):
   assert np.array_equal(pairwise, pairwise.T)  # symmetry
   assert (pairwise.diagonal() == 0).all()  # identity
   # triangular inequality
-  tol = 1e-15
+  tol = 1e-12
   assert (pairwise <= pairwise[:, :, np.newaxis] +
           pairwise[:, np.newaxis, :] + tol).all()
```

**Comment** (on `tol`): SDML was failing due to the harsh tolerance, so I changed it, but I think it's still reasonable.
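For readers decoding the broadcast: element `[i, j, k]` of the right-hand side is `d(i, j) + d(i, k)`, so the assertion checks `d(j, k) <= d(i, j) + d(i, k)` for every triple of points. A self-contained illustration on toy data (not from the PR):

```python
import numpy as np

# A valid distance matrix for three collinear points at 0, 1, and 2.
pairwise = np.array([[0., 1., 2.],
                     [1., 0., 1.],
                     [2., 1., 0.]])
tol = 1e-12  # absorbs tiny floating-point violations of the inequality
assert (pairwise <= pairwise[:, :, np.newaxis] +
        pairwise[:, np.newaxis, :] + tol).all()
```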
```diff
@@ -281,5 +281,15 @@ def test_transformer_is_2D(estimator, build_dataset):

   # test that it works for 1 feature
   trunc_data = input_data[..., :1]
+  # we drop duplicates that might have been formed, i.e. of the form
+  # aabc or abcc or aabb for quadruplets, and aa for pairs.
+  slices = {4: [slice(0, 2), slice(2, 4)], 2: [slice(0, 2)]}
+  if trunc_data.ndim == 3:
+    for slice_idx in slices[trunc_data.shape[1]]:
+      pairs = trunc_data[:, slice_idx, :]
+      diffs = pairs[:, 1, :] - pairs[:, 0, :]
+      to_keep = np.nonzero(diffs.ravel())
+      trunc_data = trunc_data[to_keep]
+      labels = labels[to_keep]
   model.fit(trunc_data, labels)
   assert model.transformer_.shape == (1, 1)  # the transformer must be 2D
```

**Comment** (on the filtering): This is a bit difficult to parse. Why do we need these slices? I am a bit lazy to check, but maybe it should be made clearer even if it is less efficient (this is a small 1D dataset anyway, so we don't care). Also, removing things that are very close to being the same (as opposed to exactly the same) might be more robust.

**Reply:** I agree it's difficult to parse; I'll change the test, copying/pasting for the quadruplets and pairs cases.

**Reply:** Done, in commit 31072d3.
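A rough sketch of the more explicit per-case filtering the review converged on (hypothetical helper; the actual rewrite is in commit 31072d3). For quadruplets the same filter would be applied to each of the two point pairs, and dropping near-duplicates, as suggested, makes it robust to values that are close but not identical:

```python
import numpy as np

def drop_degenerate_pairs(tuples, labels, atol=1e-9):
    # Keep only tuples (shape (n, 2, 1) after truncation to one feature)
    # whose two points differ by more than atol.
    diffs = tuples[:, 1, :] - tuples[:, 0, :]
    to_keep = np.abs(diffs.ravel()) > atol
    return tuples[to_keep], labels[to_keep]
```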
**Comment** (on the dependencies): I'd call out skggm separately as an optional dependency.

**Reply:** Thanks, I agree. I forgot to change that after making skggm optional.