|
12 | 12 | import warnings
|
13 | 13 | import numpy as np
|
14 | 14 | from sklearn.base import TransformerMixin
|
15 |
| -from sklearn.covariance import graph_lasso |
16 |
| -from sklearn.utils.extmath import pinvh |
| 15 | +from scipy.linalg import pinvh |
| 16 | +from sklearn.covariance import graphical_lasso |
| 17 | +from sklearn.exceptions import ConvergenceWarning |
17 | 18 |
|
18 | 19 | from .base_metric import MahalanobisMixin, _PairsClassifierMixin
|
19 | 20 | from .constraints import Constraints, wrap_pairs
|
20 | 21 | from ._util import transformer_from_metric
|
| 22 | +try: |
| 23 | + from inverse_covariance import quic |
| 24 | +except ImportError: |
| 25 | + HAS_SKGGM = False |
| 26 | +else: |
| 27 | + HAS_SKGGM = True |
21 | 28 |
|
22 | 29 |
|
23 | 30 | class _BaseSDML(MahalanobisMixin):
|
@@ -52,24 +59,74 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
|
52 | 59 | super(_BaseSDML, self).__init__(preprocessor)
|
53 | 60 |
|
def _fit(self, pairs, y):
    """Fit the SDML metric from labelled pairs.

    Builds the matrix ``prior_inv + balance_param * loss_matrix`` and feeds
    it to a graphical lasso solver (skggm's ``quic`` when it is importable,
    scikit-learn's ``graphical_lasso`` otherwise). The solver's precision
    matrix is used as the Mahalanobis matrix and converted with
    ``transformer_from_metric``.

    Parameters
    ----------
    pairs : array-like
        Pairs of points, validated by ``self._prepare_inputs`` with
        ``type_of_inputs='tuples'``. After validation it is indexed as a
        3D array whose last axis holds the features (presumably of shape
        (n_pairs, 2, n_features) -- confirm against ``_prepare_inputs``).
    y : array-like
        One weight/label per pair; it multiplies each pair's difference in
        the loss matrix (presumably +1 / -1 -- confirm against callers).

    Returns
    -------
    self
        The estimator, with ``transformer_`` set.

    Warns
    -----
    ConvergenceWarning
        If the matrix handed to the graphical lasso has a negative
        eigenvalue.

    Raises
    ------
    RuntimeError
        If the solver raises, or returns a matrix that is not finite or
        has a negative eigenvalue.
    """
    solver_name = "skggm" if HAS_SKGGM else "scikit-learn"
    if self.verbose:
        print("SDML will use {}'s graphical lasso solver."
              .format(solver_name))
    pairs, y = self._prepare_inputs(pairs, y, type_of_inputs='tuples')

    # set up (the inverse of) the prior M
    if self.use_cov:
        # Covariance of the distinct points appearing in the pairs.
        # np.unique(axis=0) de-duplicates rows and returns a real array in
        # a deterministic order; stacking a Python set is not a sequence
        # input and is deprecated/rejected by recent NumPy versions.
        X = np.unique(pairs.reshape(-1, pairs.shape[2]), axis=0)
        prior_inv = np.atleast_2d(np.cov(X, rowvar=False))
    else:
        prior_inv = np.identity(pairs.shape[2])
    diff = pairs[:, 0] - pairs[:, 1]
    # sum over pairs of y_i * diff_i diff_i^T
    loss_matrix = (diff.T * y).dot(diff)
    emp_cov = prior_inv + self.balance_param * loss_matrix

    # our initialization will be the matrix with emp_cov's eigenvalues,
    # with a constant added so that they are all positive (plus an epsilon
    # to ensure definiteness). This is empirical.
    w, V = np.linalg.eigh(emp_cov)
    min_eigval = np.min(w)
    if min_eigval < 0.:
        warnings.warn("Warning, the input matrix of graphical lasso is not "
                      "positive semi-definite (PSD). The algorithm may "
                      "diverge, and lead to degenerate solutions. "
                      "To prevent that, try to decrease the balance "
                      "parameter `balance_param` and/or to set "
                      "use_cov=False.",
                      ConvergenceWarning)
        # translate the eigenvalues so that none of them is negative
        w -= min_eigval
    w += 1e-10  # small offset to avoid definiteness problems
    sigma0 = (V * w).dot(V.T)

    # Any solver failure is collected here and reported through a single
    # RuntimeError below, together with the sanity checks on the result.
    raised_error = None
    not_spd = False
    not_finite = False
    try:
        if HAS_SKGGM:
            theta0 = pinvh(sigma0)
            M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
                                    msg=self.verbose,
                                    Theta0=theta0, Sigma0=sigma0)
        else:
            _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
                                   verbose=self.verbose,
                                   cov_init=sigma0)
        # only the eigenvalues are needed to detect a negative direction
        not_spd = bool((np.linalg.eigvalsh(M) < 0.).any())
        not_finite = not np.isfinite(M).all()
    except Exception as e:
        # deliberately broad: whatever the solver throws is re-reported
        # below with advice; the checks are not applicable in that case
        raised_error = e
        not_spd = False
        not_finite = False

    if raised_error is not None or not_spd or not_finite:
        msg = ("There was a problem in SDML when using {}'s graphical "
               "lasso solver.").format(solver_name)
        if not HAS_SKGGM:
            msg += (" skggm's graphical lasso can sometimes converge "
                    "on non SPD cases where scikit-learn's graphical "
                    "lasso fails to converge. Try to install skggm and "
                    "rerun the algorithm (see the README.md for the "
                    "right version of skggm).")
        if raised_error is not None:
            msg += " The following error message was thrown: {}.".format(
                raised_error)
        raise RuntimeError(msg)

    self.transformer_ = transformer_from_metric(np.atleast_2d(M))
    return self
|
74 | 131 |
|
75 | 132 |
|
|
0 commit comments