diff --git a/README.rst b/README.rst
index 9bb762b4..1e8adbe7 100644
--- a/README.rst
+++ b/README.rst
@@ -15,6 +15,7 @@ Metric Learning algorithms in Python.
 - Local Fisher Discriminant Analysis (LFDA)
 - Relative Components Analysis (RCA)
 - Metric Learning for Kernel Regression (MLKR)
+- Mahalanobis Metric for Clustering (MMC)
 
 **Dependencies**
 
diff --git a/metric_learn/__init__.py b/metric_learn/__init__.py
index 5a7508c0..b86c10e1 100644
--- a/metric_learn/__init__.py
+++ b/metric_learn/__init__.py
@@ -10,3 +10,4 @@
 from .lfda import LFDA
 from .rca import RCA, RCA_Supervised
 from .mlkr import MLKR
+from .mmc import MMC, MMC_Supervised
diff --git a/metric_learn/_util.py b/metric_learn/_util.py
new file mode 100644
index 00000000..b34860d6
--- /dev/null
+++ b/metric_learn/_util.py
@@ -0,0 +1,12 @@
+import numpy as np
+
+
+# hack around lack of axis kwarg in older numpy versions
+try:
+  np.linalg.norm([[4]], axis=1)
+except TypeError:
+  def vector_norm(X):
+    return np.apply_along_axis(np.linalg.norm, 1, X)
+else:
+  def vector_norm(X):
+    return np.linalg.norm(X, axis=1)
\ No newline at end of file
diff --git a/metric_learn/itml.py b/metric_learn/itml.py
index 4c154ad4..7169fb36 100644
--- a/metric_learn/itml.py
+++ b/metric_learn/itml.py
@@ -21,6 +21,7 @@
 
 from .base_metric import BaseMetricLearner
 from .constraints import Constraints
+from ._util import vector_norm
 
 
 class ITML(BaseMetricLearner):
@@ -54,10 +55,10 @@ def _process_inputs(self, X, constraints, bounds):
     self.X_ = X = check_array(X)
     # check to make sure that no two constrained vectors are identical
     a,b,c,d = constraints
-    ident = _vector_norm(X[a] - X[b]) > 1e-9
-    a, b = a[ident], b[ident]
-    ident = _vector_norm(X[c] - X[d]) > 1e-9
-    c, d = c[ident], d[ident]
+    no_ident = vector_norm(X[a] - X[b]) > 1e-9
+    a, b = a[no_ident], b[no_ident]
+    no_ident = vector_norm(X[c] - X[d]) > 1e-9
+    c, d = c[no_ident], d[no_ident]
     # init bounds
     if bounds is None:
       self.bounds_ = np.percentile(pairwise_distances(X), (5, 95))
@@ -138,16 +139,6 @@ def fit(self, X, constraints, bounds=None):
   def metric(self):
     return self.A_
 
-# hack around lack of axis kwarg in older numpy versions
-try:
-  np.linalg.norm([[4]], axis=1)
-except TypeError:
-  def _vector_norm(X):
-    return np.apply_along_axis(np.linalg.norm, 1, X)
-else:
-  def _vector_norm(X):
-    return np.linalg.norm(X, axis=1)
-
 
 class ITML_Supervised(ITML):
   """Information Theoretic Metric Learning (ITML)"""
diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py
new file mode 100644
index 00000000..7760e1b1
--- /dev/null
+++ b/metric_learn/mmc.py
@@ -0,0 +1,436 @@
+"""
+Mahalanobis Metric Learning with Application for Clustering with Side-Information, Xing et al., NIPS 2002
+
+MMC minimizes the sum of squared distances between similar examples,
+while enforcing the sum of distances between dissimilar examples to be
+greater than a certain margin.
+This leads to a convex and, thus, local-minima-free optimization problem
+that can be solved efficiently.
+However, the algorithm involves the computation of eigenvalues, which is the
+main speed bottleneck.
+Since it was initially designed for clustering applications, one of the
+implicit assumptions of MMC is that all classes form a compact set, i.e.,
+follow a unimodal distribution, which restricts the possible use-cases of
+this method. However, it is one of the earliest and a still often cited technique.
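+
+As a rough sketch (in the notation of Xing et al., not a literal transcription
+of the implementation below), the full-metric problem being solved is
+
+  min_A    \sum_{ij \in S} d_A(x_i, x_j)^2
+  s.t.     \sum_{ij \in D} d_A(x_i, x_j) >= 1,   A positive semi-definite
+
+where d_A(x, y) = \sqrt{(x-y)' A (x-y)}, S is the set of similar pairs and D
+the set of dissimilar pairs.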
+
+Adapted from Matlab code at http://www.cs.cmu.edu/%7Eepxing/papers/Old_papers/code_Metric_online.tar.gz
+"""
+
+from __future__ import print_function, absolute_import, division
+import numpy as np
+from six.moves import xrange
+from sklearn.metrics import pairwise_distances
+from sklearn.utils.validation import check_array, check_X_y
+
+from .base_metric import BaseMetricLearner
+from .constraints import Constraints
+from ._util import vector_norm
+
+
+class MMC(BaseMetricLearner):
+  """Mahalanobis Metric for Clustering (MMC)"""
+  def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3,
+               A0=None, diagonal=False, diagonal_c=1.0, verbose=False):
+    """Initialize MMC.
+    Parameters
+    ----------
+    max_iter : int, optional
+    max_proj : int, optional
+    convergence_threshold : float, optional
+    A0 : (d x d) matrix, optional
+        initial metric, defaults to identity
+        only the main diagonal is taken if `diagonal == True`
+    diagonal : bool, optional
+        if True, a diagonal metric will be learned,
+        i.e., a simple scaling of dimensions
+    diagonal_c : float, optional
+        weight of the dissimilarity constraint for diagonal
+        metric learning
+    verbose : bool, optional
+        if True, prints information while learning
+    """
+    self.max_iter = max_iter
+    self.max_proj = max_proj
+    self.convergence_threshold = convergence_threshold
+    self.A0 = A0
+    self.diagonal = diagonal
+    self.diagonal_c = diagonal_c
+    self.verbose = verbose
+
+  def fit(self, X, constraints):
+    """Learn the MMC model.
+    Parameters
+    ----------
+    X : (n x d) data matrix
+        each row corresponds to a single instance
+    constraints : 4-tuple of arrays
+        (a,b,c,d) indices into X, with (a,b) specifying similar and (c,d)
+        dissimilar pairs
+    """
+    constraints = self._process_inputs(X, constraints)
+    if self.diagonal:
+      return self._fit_diag(X, constraints)
+    else:
+      return self._fit_full(X, constraints)
+
+  def _process_inputs(self, X, constraints):
+
+    self.X_ = X = check_array(X)
+
+    # check to make sure that no two constrained vectors are identical
+    a,b,c,d = constraints
+    no_ident = vector_norm(X[a] - X[b]) > 1e-9
+    a, b = a[no_ident], b[no_ident]
+    no_ident = vector_norm(X[c] - X[d]) > 1e-9
+    c, d = c[no_ident], d[no_ident]
+    if len(a) == 0:
+      raise ValueError('No non-trivial similarity constraints given for MMC.')
+    if len(c) == 0:
+      raise ValueError('No non-trivial dissimilarity constraints given for MMC.')
+
+    # init metric
+    if self.A0 is None:
+      self.A_ = np.identity(X.shape[1])
+      if not self.diagonal:
+        # Don't know why division by 10... it's in the original code
+        # and seems to affect the overall scale of the learned metric.
+        self.A_ /= 10.0
+    else:
+      self.A_ = check_array(self.A0)
+
+    return a,b,c,d
+
+  def _fit_full(self, X, constraints):
+    """Learn full metric using MMC.
+    Parameters
+    ----------
+    X : (n x d) data matrix
+        each row corresponds to a single instance
+    constraints : 4-tuple of arrays
+        (a,b,c,d) indices into X, with (a,b) specifying similar and (c,d)
+        dissimilar pairs
+    """
+    a,b,c,d = constraints
+    num_pos = len(a)
+    num_neg = len(c)
+    num_samples, num_dim = X.shape
+
+    error1 = error2 = 1e10
+    eps = 0.01  # error-bound of iterative projection on C1 and C2
+    A = self.A_
+
+    # Create weight vector from similar samples
+    pos_diff = X[a] - X[b]
+    w = np.einsum('ij,ik->jk', pos_diff, pos_diff).ravel()
+    # `w` is the sum of all outer products of the rows in `pos_diff`.
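+    # Reshaped back to (d, d), `w` equals pos_diff.T.dot(pos_diff), so that
+    # w.dot(A.ravel()) gives \sum_{ij \in S} d_ij' A d_ij, the quantity f(A)
+    # that is bounded in the projection step below.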
+    # The above `einsum` is equivalent to the much more inefficient:
+    # w = np.apply_along_axis(
+    #   lambda x: np.outer(x,x).ravel(),
+    #   1,
+    #   X[a] - X[b]
+    # ).sum(axis = 0)
+    t = w.dot(A.ravel()) / 100.0
+
+    w_norm = np.linalg.norm(w)
+    w1 = w / w_norm  # make `w` a unit vector
+    t1 = t / w_norm  # distance from origin to `w^T*x=t` plane
+
+    cycle = 1
+    alpha = 0.1  # initial step size along gradient
+
+    grad1 = self._fS1(X, a, b, A)            # gradient of similarity constraint function
+    grad2 = self._fD1(X, c, d, A)            # gradient of dissimilarity constraint function
+    M = self._grad_projection(grad1, grad2)  # gradient of fD1 orthogonal to fS1
+
+    A_old = A.copy()
+
+    for cycle in xrange(self.max_iter):
+
+      # projection of constraints C1 and C2
+      satisfy = False
+
+      for it in xrange(self.max_proj):
+
+        # First constraint:
+        # f(A) = \sum_{i,j \in S} d_ij' A d_ij <= t              (1)
+        # (1) can be rewritten as a linear constraint: w^T x = t,
+        # where x is the unrolled matrix of A,
+        # w is also an unrolled matrix of W where
+        # W_{kl}= \sum_{i,j \in S}d_ij^k * d_ij^l
+        x0 = A.ravel()
+        if w.dot(x0) <= t:
+          x = x0
+        else:
+          x = x0 + (t1 - w1.dot(x0)) * w1
+        A[:] = x.reshape(num_dim, num_dim)
+
+        # Second constraint:
+        # PSD constraint A >= 0
+        # project A onto domain A>0
+        l, V = np.linalg.eigh((A + A.T) / 2)
+        A[:] = np.dot(V * np.maximum(0, l[None,:]), V.T)
+
+        fDC2 = w.dot(A.ravel())
+        error2 = (fDC2 - t) / t
+        if error2 < eps:
+          satisfy = True
+          break
+
+      # third constraint: gradient ascent
+      # max: g(A) >= 1
+      # here we suppose g(A) = fD(A) = \sum_{I,J \in D} sqrt(d_ij' A d_ij)
+
+      obj_previous = self._fD(X, c, d, A_old)  # g(A_old)
+      obj = self._fD(X, c, d, A)               # g(A)
+
+      if satisfy and (obj > obj_previous or cycle == 0):
+
+        # If projection of 1 and 2 is successful, and such projection
+        # improves objective function, slightly increase learning rate
+        # and update from the current A.
+        alpha *= 1.05
+        A_old[:] = A
+        grad2 = self._fS1(X, a, b, A)
+        grad1 = self._fD1(X, c, d, A)
+        M = self._grad_projection(grad1, grad2)
+        A += alpha * M
+
+      else:
+
+        # If projection of 1 and 2 failed, or obj <= obj_previous due
+        # to projection of 1 and 2, shrink learning rate and re-update
+        # from the previous A.
+        alpha /= 2
+        A[:] = A_old + alpha * M
+
+      delta = np.linalg.norm(alpha * M) / np.linalg.norm(A_old)
+      if delta < self.convergence_threshold:
+        break
+      if self.verbose:
+        print('mmc iter: %d, conv = %f, projections = %d' % (cycle, delta, it+1))
+
+    if delta > self.convergence_threshold:
+      self.converged_ = False
+      if self.verbose:
+        print('mmc did not converge, conv = %f' % (delta,))
+    else:
+      self.converged_ = True
+      if self.verbose:
+        print('mmc converged at iter %d, conv = %f' % (cycle, delta))
+    self.A_[:] = A_old
+    self.n_iter_ = cycle
+    return self
+
+  def _fit_diag(self, X, constraints):
+    """Learn diagonal metric using MMC.
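+
+    Rough sketch of the scheme implemented below (not an exact statement of
+    the objective): only the diagonal w = diag(A) >= 0 is learned, via damped
+    Newton-Raphson steps of the form
+
+      w <- max(0, w - lambda * H^{-1} g),
+
+    where g and H are a gradient and Hessian built from the similarity term
+    and the `diagonal_c`-weighted dissimilarity term, and the step size lambda
+    is repeatedly divided by `reduction` while the objective keeps improving,
+    keeping the best iterate.
+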
+    Parameters
+    ----------
+    X : (n x d) data matrix
+        each row corresponds to a single instance
+    constraints : 4-tuple of arrays
+        (a,b,c,d) indices into X, with (a,b) specifying similar and (c,d)
+        dissimilar pairs
+    """
+    a,b,c,d = constraints
+    num_pos = len(a)
+    num_neg = len(c)
+    num_samples, num_dim = X.shape
+
+    s_sum = np.sum((X[a] - X[b]) ** 2, axis=0)
+
+    it = 0
+    error = 1.0
+    eps = 1e-6
+    reduction = 2.0
+    w = np.diag(self.A_).copy()
+
+    while error > self.convergence_threshold:
+
+      fD0, fD_1st_d, fD_2nd_d = self._D_constraint(X, c, d, w)
+      obj_initial = np.dot(s_sum, w) + self.diagonal_c * fD0
+      fS_1st_d = s_sum  # first derivative of the similarity constraints
+
+      gradient = fS_1st_d - self.diagonal_c * fD_1st_d               # gradient of the objective
+      hessian = -self.diagonal_c * fD_2nd_d + eps * np.eye(num_dim)  # Hessian of the objective
+      step = np.dot(np.linalg.inv(hessian), gradient)
+
+      # Newton-Raphson update
+      # search over optimal lambda
+      lambd = 1  # initial step-size
+      w_tmp = np.maximum(0, w - lambd * step)
+
+      obj = np.dot(s_sum, w_tmp) + self.diagonal_c * self._D_objective(X, c, d, w_tmp)
+      obj_previous = obj * 1.1  # just to get the while-loop started
+
+      inner_it = 0
+      while obj < obj_previous:
+        obj_previous = obj
+        w_previous = w_tmp.copy()
+        lambd /= reduction
+        w_tmp = np.maximum(0, w - lambd * step)
+        obj = np.dot(s_sum, w_tmp) + self.diagonal_c * self._D_objective(X, c, d, w_tmp)
+        inner_it += 1
+
+      w[:] = w_previous
+      error = np.abs((obj_previous - obj_initial) / obj_previous)
+      if self.verbose:
+        print('mmc iter: %d, conv = %f' % (it, error))
+      it += 1
+
+    self.A_ = np.diag(w)
+    return self
+
+  def _fD(self, X, c, d, A):
+    """The value of the dissimilarity constraint function.
+
+    f = f(\sum_{ij \in D} distance(x_i, x_j))
+    i.e. distance can be L1: \sqrt{(x_i-x_j)A(x_i-x_j)'}
+    """
+    diff = X[c] - X[d]
+    return np.log(np.sum(np.sqrt(np.sum(np.dot(diff, A) * diff, axis=1))) + 1e-6)
+
+  def _fD1(self, X, c, d, A):
+    """The gradient of the dissimilarity constraint function w.r.t. A.
+
+    For example, let distance be the L1 norm:
+    f = f(\sum_{ij \in D} \sqrt{(x_i-x_j)A(x_i-x_j)'})
+    df/dA_{kl} = f'* d(\sum_{ij \in D} \sqrt{(x_i-x_j)^k*(x_i-x_j)^l})/dA_{kl}
+
+    Note that d_ij*A*d_ij' = tr(d_ij*A*d_ij') = tr(d_ij'*d_ij*A)
+    so, d(d_ij*A*d_ij')/dA = d_ij'*d_ij
+    df/dA = f'(\sum_{ij \in D} \sqrt{tr(d_ij'*d_ij*A)})
+            * 0.5*(\sum_{ij \in D} (1/sqrt{tr(d_ij'*d_ij*A)})*(d_ij'*d_ij))
+    """
+    dim = X.shape[1]
+    diff = X[c] - X[d]
+    # outer products of all rows in `diff`
+    M = np.einsum('ij,ik->ijk', diff, diff)
+    # faster version of: dist = np.sqrt(np.sum(M * A[None,:,:], axis=(1,2)))
+    dist = np.sqrt(np.einsum('ijk,jk', M, A))
+    # faster version of: sum_deri = np.sum(M / (2 * (dist[:,None,None] + 1e-6)), axis=0)
+    sum_deri = np.einsum('ijk,i->jk', M, 0.5 / (dist + 1e-6))
+    sum_dist = dist.sum()
+    return sum_deri / (sum_dist + 1e-6)
+
+  def _fS1(self, X, a, b, A):
+    """The gradient of the similarity constraint function w.r.t. A.
+
+    f = \sum_{ij}(x_i-x_j)A(x_i-x_j)' = \sum_{ij}d_ij*A*d_ij'
+    df/dA = d(d_ij*A*d_ij')/dA
+
+    Note that d_ij*A*d_ij' = tr(d_ij*A*d_ij') = tr(d_ij'*d_ij*A)
+    so, d(d_ij*A*d_ij')/dA = d_ij'*d_ij
+    """
+    dim = X.shape[1]
+    diff = X[a] - X[b]
+    return np.einsum('ij,ik->jk', diff, diff)  # sum of outer products of all rows in `diff`
+
+  def _grad_projection(self, grad1, grad2):
+    grad2 = grad2 / np.linalg.norm(grad2)
+    gtemp = grad1 - np.sum(grad1 * grad2) * grad2
+    gtemp /= np.linalg.norm(gtemp)
+    return gtemp
+
+  def _D_objective(self, X, c, d, w):
+    return np.log(np.sum(np.sqrt(np.sum(((X[c] - X[d]) ** 2) * w[None,:], axis=1) + 1e-6)))
+
+  def _D_constraint(self, X, c, d, w):
+    """Compute the value, 1st derivative, second derivative (Hessian) of
+    a dissimilarity constraint function gF(sum_ij distance(d_ij A d_ij))
+    where A is a diagonal matrix (in the form of a column vector 'w').
+    """
+    diff = X[c] - X[d]
+    diff_sq = diff * diff
+    dist = np.sqrt(diff_sq.dot(w))
+    sum_deri1 = np.einsum('ij,i', diff_sq, 0.5 / np.maximum(dist, 1e-6))
+    sum_deri2 = np.einsum(
+        'ij,ik->jk',
+        diff_sq,
+        diff_sq / (-4 * np.maximum(1e-6, dist**3))[:,None]
+    )
+    sum_dist = dist.sum()
+    return (
+        np.log(sum_dist),
+        sum_deri1 / sum_dist,
+        sum_deri2 / sum_dist - np.outer(sum_deri1, sum_deri1) / (sum_dist * sum_dist)
+    )
+
+  def metric(self):
+    return self.A_
+
+  def transformer(self):
+    """Computes the transformation matrix from the Mahalanobis matrix.
+    L = w^(1/2) * V.T, with A = V*w*V.T being the eigenvector decomposition of A with
+    the eigenvalues in the diagonal matrix w and the columns of V being the eigenvectors.
+
+    The Cholesky decomposition cannot be applied here, since MMC learns only a positive
+    *semi*-definite Mahalanobis matrix.
+
+    Returns
+    -------
+    L : (d x d) matrix
+    """
+    if self.diagonal:
+      return np.sqrt(self.A_)
+    else:
+      w, V = np.linalg.eigh(self.A_)
+      return V.T * np.sqrt(np.maximum(0, w[:,None]))
+
+
+class MMC_Supervised(MMC):
+  """Mahalanobis Metric for Clustering (MMC)"""
+  def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
+               num_labeled=np.inf, num_constraints=None,
+               A0=None, diagonal=False, diagonal_c=1.0, verbose=False):
+    """Initialize the learner.
+    Parameters
+    ----------
+    max_iter : int, optional
+    max_proj : int, optional
+    convergence_threshold : float, optional
+    num_labeled : int, optional
+        number of labels to preserve for training
+    num_constraints: int, optional
+        number of constraints to generate
+    A0 : (d x d) matrix, optional
+        initial metric, defaults to identity
+        only the main diagonal is taken if `diagonal == True`
+    diagonal : bool, optional
+        if True, a diagonal metric will be learned,
+        i.e., a simple scaling of dimensions
+    diagonal_c : float, optional
+        weight of the dissimilarity constraint for diagonal
+        metric learning
+    verbose : bool, optional
+        if True, prints information while learning
+    """
+    MMC.__init__(self, max_iter=max_iter, max_proj=max_proj,
+                 convergence_threshold=convergence_threshold,
+                 A0=A0, diagonal=diagonal, diagonal_c=diagonal_c,
+                 verbose=verbose)
+    self.num_labeled = num_labeled
+    self.num_constraints = num_constraints
+
+  def fit(self, X, y, random_state=np.random):
+    """Create constraints from labels and learn the MMC model.
+    Parameters
+    ----------
+    X : (n x d) matrix
+        Input data, where each row corresponds to a single instance.
+    y : (n) array-like
+        Data labels.
+    random_state : numpy.random.RandomState, optional
+        If provided, controls random number generation.
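+
+    Example
+    -------
+    Illustrative sketch only (not a doctest); it assumes scikit-learn's iris
+    dataset purely for demonstration::
+
+      from sklearn.datasets import load_iris
+      from metric_learn import MMC_Supervised
+
+      iris = load_iris()
+      mmc = MMC_Supervised(num_constraints=200)
+      mmc.fit(iris.data, iris.target)
+      X_transformed = mmc.transform(iris.data)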
+ """ + X, y = check_X_y(X, y) + num_constraints = self.num_constraints + if num_constraints is None: + num_classes = len(np.unique(y)) + num_constraints = 20 * num_classes**2 + + c = Constraints.random_subset(y, self.num_labeled, + random_state=random_state) + pos_neg = c.positive_negative_pairs(num_constraints, + random_state=random_state) + return MMC.fit(self, X, pos_neg) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 1e7f31fe..351b6298 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -6,8 +6,8 @@ from numpy.testing import assert_array_almost_equal from metric_learn import ( - LMNN, NCA, LFDA, Covariance, MLKR, - LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised) + LMNN, NCA, LFDA, Covariance, MLKR, MMC, + LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised) # Import this specially for testing. from metric_learn.lmnn import python_LMNN @@ -149,5 +149,42 @@ def test_iris(self): self.assertLess(csep, 0.25) +class TestMMC(MetricTestCase): + def test_iris(self): + + # Generate full set of constraints for comparison with reference implementation + n = self.iris_points.shape[0] + mask = (self.iris_labels[None] == self.iris_labels[:,None]) + a, b = np.nonzero(np.triu(mask, k=1)) + c, d = np.nonzero(np.triu(~mask, k=1)) + + # Full metric + mmc = MMC(convergence_threshold=0.01) + mmc.fit(self.iris_points, [a,b,c,d]) + expected = [[+0.00046504, +0.00083371, -0.00111959, -0.00165265], + [+0.00083371, +0.00149466, -0.00200719, -0.00296284], + [-0.00111959, -0.00200719, +0.00269546, +0.00397881], + [-0.00165265, -0.00296284, +0.00397881, +0.00587320]] + assert_array_almost_equal(expected, mmc.metric(), decimal=6) + + # Diagonal metric + mmc = MMC(diagonal=True) + mmc.fit(self.iris_points, [a,b,c,d]) + expected = [0, 0, 1.21045968, 1.22552608] + assert_array_almost_equal(np.diag(expected), mmc.metric(), decimal=6) + + # Supervised Full + mmc = MMC_Supervised() + mmc.fit(self.iris_points, self.iris_labels) + csep = class_separation(mmc.transform(), self.iris_labels) + self.assertLess(csep, 0.15) + + # Supervised Diagonal + mmc = MMC_Supervised(diagonal=True) + mmc.fit(self.iris_points, self.iris_labels) + csep = class_separation(mmc.transform(), self.iris_labels) + self.assertLess(csep, 0.2) + + if __name__ == '__main__': unittest.main() diff --git a/test/test_base_metric.py b/test/test_base_metric.py index d73138cd..31db4e6f 100644 --- a/test/test_base_metric.py +++ b/test/test_base_metric.py @@ -63,5 +63,16 @@ def test_mlkr(self): "MLKR(A0=None, alpha=0.0001, epsilon=0.01, " "max_iter=1000, num_dims=None)") + def test_mmc(self): + self.assertEqual(str(metric_learn.MMC()), """ +MMC(A0=None, convergence_threshold=0.001, diagonal=False, diagonal_c=1.0, + max_iter=100, max_proj=10000, verbose=False) +""".strip('\n')) + self.assertEqual(str(metric_learn.MMC_Supervised()), """ +MMC_Supervised(A0=None, convergence_threshold=1e-06, diagonal=False, + diagonal_c=1.0, max_iter=100, max_proj=10000, num_constraints=None, + num_labeled=inf, verbose=False) +""".strip('\n')) + if __name__ == '__main__': unittest.main() diff --git a/test/test_fit_transform.py b/test/test_fit_transform.py index eff8fa01..707815ec 100644 --- a/test/test_fit_transform.py +++ b/test/test_fit_transform.py @@ -5,7 +5,7 @@ from metric_learn import ( LMNN, NCA, LFDA, Covariance, MLKR, - LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised) + LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised) class 
@@ -118,6 +118,18 @@ def test_mlkr(self):
 
     assert_array_almost_equal(res_1, res_2)
 
+  def test_mmc_supervised(self):
+    seed = np.random.RandomState(1234)
+    mmc = MMC_Supervised(num_constraints=200)
+    mmc.fit(self.X, self.y, random_state=seed)
+    res_1 = mmc.transform()
+
+    seed = np.random.RandomState(1234)
+    mmc = MMC_Supervised(num_constraints=200)
+    res_2 = mmc.fit_transform(self.X, self.y, random_state=seed)
+
+    assert_array_almost_equal(res_1, res_2)
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py
index 58c7cd05..f1e1a09d 100644
--- a/test/test_sklearn_compat.py
+++ b/test/test_sklearn_compat.py
@@ -4,7 +4,7 @@
 
 from metric_learn import (
     LMNN, NCA, LFDA, Covariance, MLKR,
-    LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised)
+    LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised)
 
 
 # Wrap the _Supervised methods with a deterministic wrapper for testing.
@@ -22,6 +22,10 @@ class dITML(deterministic_mixin, ITML_Supervised):
   pass
 
 
+class dMMC(deterministic_mixin, MMC_Supervised):
+  pass
+
+
 class dSDML(deterministic_mixin, SDML_Supervised):
   pass
 
@@ -52,6 +56,9 @@ def test_lsml(self):
   def test_itml(self):
     check_estimator(dITML)
 
+  def test_mmc(self):
+    check_estimator(dMMC)
+
   # This fails due to a FloatingPointError
   # def test_sdml(self):
   #   check_estimator(dSDML)