diff --git a/metric_learn/__init__.py b/metric_learn/__init__.py
index 5efb9f5c..638d6d4d 100644
--- a/metric_learn/__init__.py
+++ b/metric_learn/__init__.py
@@ -1,9 +1,10 @@
 from __future__ import absolute_import
-from .itml import ITML
+from .itml import ITML, ITML_Supervised
 from .lmnn import LMNN
-from .lsml import LSML
-from .sdml import SDML
+from .lsml import LSML, LSML_Supervised
+from .sdml import SDML, SDML_Supervised
 from .nca import NCA
 from .lfda import LFDA
-from .rca import RCA
+from .rca import RCA, RCA_Supervised
+from .constraints import adjacencyMatrix, positiveNegativePairs, relativeQuadruplets, chunks
 
diff --git a/metric_learn/constraints.py b/metric_learn/constraints.py
new file mode 100644
index 00000000..58a2768a
--- /dev/null
+++ b/metric_learn/constraints.py
@@ -0,0 +1,62 @@
+"""Helper module that can generate different types of constraints from supervised data labels."""
+
+import numpy as np
+import random
+from six.moves import xrange
+
+# @TODO: consider creating a stateful class
+# https://github.com/all-umass/metric-learn/pull/19#discussion_r67386226
+
+def adjacencyMatrix(labels, num_points, num_constraints):
+  a, c = np.random.randint(len(labels), size=(2, num_constraints))
+  b, d = np.empty((2, num_constraints), dtype=int)
+  for i, (al, cl) in enumerate(zip(labels[a], labels[c])):
+    b[i] = random.choice(np.nonzero(labels == al)[0])
+    d[i] = random.choice(np.nonzero(labels != cl)[0])
+  W = np.zeros((num_points, num_points))
+  W[a, b] = 1
+  W[c, d] = -1
+  # make W symmetric
+  W[b, a] = 1
+  W[d, c] = -1
+  return W
+
+def positiveNegativePairs(labels, num_points, num_constraints):
+  ac, bd = np.random.randint(num_points, size=(2, num_constraints))
+  pos = labels[ac] == labels[bd]
+  a, c = ac[pos], ac[~pos]
+  b, d = bd[pos], bd[~pos]
+  return a, b, c, d
+
+def relativeQuadruplets(labels, num_constraints):
+  C = np.empty((num_constraints, 4), dtype=int)
+  a, c = np.random.randint(len(labels), size=(2, num_constraints))
+  for i, (al, cl) in enumerate(zip(labels[a], labels[c])):
+    C[i, 1] = random.choice(np.nonzero(labels == al)[0])
+    C[i, 3] = random.choice(np.nonzero(labels != cl)[0])
+  C[:, 0] = a
+  C[:, 2] = c
+  return C
+
+def chunks(Y, num_chunks=100, chunk_size=2, seed=None):
+  # @TODO: remove seed from params and use numpy RandomState
+  # https://github.com/all-umass/metric-learn/pull/19#discussion_r67386666
+  random.seed(seed)
+  chunks = -np.ones_like(Y, dtype=int)
+  uniq, lookup = np.unique(Y, return_inverse=True)
+  all_inds = [set(np.where(lookup == c)[0]) for c in xrange(len(uniq))]
+  idx = 0
+  while idx < num_chunks and all_inds:
+    c = random.randint(0, len(all_inds) - 1)
+    inds = all_inds[c]
+    if len(inds) < chunk_size:
+      del all_inds[c]
+      continue
+    ii = random.sample(inds, chunk_size)
+    inds.difference_update(ii)
+    chunks[ii] = idx
+    idx += 1
+  if idx < num_chunks:
+    raise ValueError('Unable to make %d chunks of %d examples each' %
+                     (num_chunks, chunk_size))
+  return chunks
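These helpers replace the per-learner `prepare_constraints` classmethods removed in the hunks below. A minimal usage sketch (toy labels; not part of the patch):

```python
import numpy as np
from metric_learn.constraints import positiveNegativePairs, relativeQuadruplets

labels = np.array([0, 0, 1, 1, 2, 2])

# Pair constraints: a[i] and b[i] share a label, c[j] and d[j] do not.
a, b, c, d = positiveNegativePairs(labels, len(labels), 4)

# Quadruplet constraints: each row (a, b, c, d) encodes d(a,b) < d(c,d),
# where a and b share a label and d's label differs from c's.
C = relativeQuadruplets(labels, 4)
```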
diff --git a/metric_learn/itml.py b/metric_learn/itml.py
index c6ad7e97..95636a9a 100644
--- a/metric_learn/itml.py
+++ b/metric_learn/itml.py
@@ -16,11 +16,12 @@ from six.moves import xrange
 from sklearn.metrics import pairwise_distances
 
 from .base_metric import BaseMetricLearner
+from .constraints import positiveNegativePairs
 
 
 class ITML(BaseMetricLearner):
   """Information Theoretic Metric Learning (ITML)"""
-  def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3):
+  def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3, verbose=False):
     """Initialize the learner.
 
@@ -29,11 +30,14 @@ def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3):
       value for slack variables
     max_iters : int, optional
     convergence_threshold : float, optional
+    verbose : bool, optional
+      if True, prints information while learning
     """
     self.params = {
       'gamma': gamma,
       'max_iters': max_iters,
       'convergence_threshold': convergence_threshold,
+      'verbose': verbose,
     }
 
   def _process_inputs(self, X, constraints, bounds, A0):
@@ -57,7 +61,7 @@ def _process_inputs(self, X, constraints, bounds, A0):
       self.A = A0
     return a,b,c,d
 
-  def fit(self, X, constraints, bounds=None, A0=None, verbose=False):
+  def fit(self, X, constraints, bounds=None, A0=None):
     """Learn the ITML model.
 
     Parameters
@@ -71,6 +75,7 @@ def fit(self, X, constraints, bounds=None, A0=None, verbose=False):
     A0 : (d x d) matrix, optional
       initial regularization matrix, defaults to identity
     """
+    verbose = self.params['verbose']
     a,b,c,d = self._process_inputs(X, constraints, bounds, A0)
     gamma = self.params['gamma']
     conv_thresh = self.params['convergence_threshold']
@@ -121,14 +126,6 @@ def fit(self, X, constraints, bounds=None, A0=None, verbose=False):
   def metric(self):
     return self.A
 
-  @classmethod
-  def prepare_constraints(self, labels, num_points, num_constraints):
-    ac,bd = np.random.randint(num_points, size=(2,num_constraints))
-    pos = labels[ac] == labels[bd]
-    a,c = ac[pos], ac[~pos]
-    b,d = bd[pos], bd[~pos]
-    return a,b,c,d
-
 # hack around lack of axis kwarg in older numpy versions
 try:
   np.linalg.norm([[4]], axis=1)
@@ -138,3 +135,46 @@ def _vector_norm(X):
 else:
   def _vector_norm(X):
     return np.linalg.norm(X, axis=1)
+
+
+class ITML_Supervised(ITML):
+  """Supervised version of Information Theoretic Metric Learning (ITML)"""
+  def __init__(self, gamma=1., max_iters=1000, convergence_threshold=1e-3, num_constraints=None,
+               bounds=None, A0=None, verbose=False):
+    """Initialize the learner.
+
+    Parameters
+    ----------
+    gamma : float, optional
+      value for slack variables
+    max_iters : int, optional
+    convergence_threshold : float, optional
+    num_constraints : int, optional (default: 20 * num_classes**2)
+    verbose : bool, optional
+      if True, prints information while learning
+    """
+    ITML.__init__(self, gamma=gamma, max_iters=max_iters,
+                  convergence_threshold=convergence_threshold, verbose=verbose)
+    self.params.update({
+      'num_constraints': num_constraints,
+      'bounds': bounds,
+      'A0': A0,
+    })
+
+  def fit(self, X, labels):
+    """Create constraints from labels and learn the ITML model.
+    Defaults to 20 * num_classes**2 constraints if num_constraints was not set.
+
+    Parameters
+    ----------
+    X : (n x d) data matrix
+      each row corresponds to a single instance
+    labels : (n) data labels
+    """
+    num_constraints = self.params['num_constraints']
+    if num_constraints is None:
+      num_classes = len(np.unique(labels))
+      num_constraints = 20 * num_classes**2
+
+    C = positiveNegativePairs(labels, X.shape[0], num_constraints)
+    return ITML.fit(self, X, C, bounds=self.params['bounds'], A0=self.params['A0'])
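The supervised wrapper collapses the old two-step workflow (generate constraints, then fit) into a single call. A sketch on iris, mirroring the updated tests at the end of this diff:

```python
from sklearn.datasets import load_iris
from metric_learn import ITML_Supervised

iris = load_iris()
# With num_constraints omitted, fit() falls back to 20 * num_classes**2.
itml = ITML_Supervised(num_constraints=200)
itml.fit(iris.data, iris.target)
X_itml = itml.transform()  # the data mapped into the learned metric space
```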
diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py
index 5552a73b..189b4e83 100644
--- a/metric_learn/lmnn.py
+++ b/metric_learn/lmnn.py
@@ -29,7 +29,7 @@ def transformer(self):
 # slower Python version
 class python_LMNN(_base_LMNN):
   def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
-               regularization=0.5, convergence_tol=0.001):
+               regularization=0.5, convergence_tol=0.001, verbose=False):
     """Initialize the LMNN object
 
     k: number of neighbors to consider. (does not include self-edges)
@@ -37,7 +37,7 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
     """
     _base_LMNN.__init__(self, k=k, min_iter=min_iter, max_iter=max_iter,
                         learn_rate=learn_rate, regularization=regularization,
-                        convergence_tol=convergence_tol)
+                        convergence_tol=convergence_tol, verbose=verbose)
 
   def _process_inputs(self, X, labels):
     num_pts = X.shape[0]
@@ -51,8 +51,9 @@ def _process_inputs(self, X, labels):
           'not enough class labels for specified k'
           ' (smallest class has %d)' % required_k)
 
-  def fit(self, X, labels, verbose=False):
+  def fit(self, X, labels):
     k = self.params['k']
+    verbose = self.params['verbose']
     reg = self.params['regularization']
     learn_rate = self.params['learn_rate']
     convergence_tol = self.params['convergence_tol']
@@ -236,12 +237,12 @@ def _sum_outer_products(data, a_inds, b_inds, weights=None):
 
   class LMNN(_base_LMNN):
     def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
-                 regularization=0.5, convergence_tol=0.001, use_pca=True):
+                 regularization=0.5, convergence_tol=0.001, use_pca=True, verbose=False):
       _base_LMNN.__init__(self, k=k, min_iter=min_iter, max_iter=max_iter,
                           learn_rate=learn_rate, regularization=regularization,
-                          convergence_tol=convergence_tol, use_pca=use_pca)
+                          convergence_tol=convergence_tol, use_pca=use_pca, verbose=verbose)
 
-    def fit(self, X, labels, verbose=False):
+    def fit(self, X, labels):
       self.X = X
       self.L = np.eye(X.shape[1])
       labels = MulticlassLabels(labels.astype(np.float64))
diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py
index 108bd064..f3bf9738 100644
--- a/metric_learn/lsml.py
+++ b/metric_learn/lsml.py
@@ -13,20 +13,24 @@ from random import choice
 from six.moves import xrange
 
 from .base_metric import BaseMetricLearner
+from .constraints import relativeQuadruplets
 
 
 class LSML(BaseMetricLearner):
-  def __init__(self, tol=1e-3, max_iter=1000):
+  def __init__(self, tol=1e-3, max_iter=1000, verbose=False):
     """Initialize the learner.
 
     Parameters
     ----------
     tol : float, optional
     max_iter : int, optional
+    verbose : bool, optional
+      if True, prints information while learning
     """
     self.params = {
       'tol': tol,
       'max_iter': max_iter,
+      'verbose': verbose,
     }
 
   def _prepare_inputs(self, X, constraints, weights, prior):
@@ -46,7 +50,7 @@ def _prepare_inputs(self, X, constraints, weights, prior):
   def metric(self):
     return self.M
 
-  def fit(self, X, constraints, weights=None, prior=None, verbose=False):
+  def fit(self, X, constraints, weights=None, prior=None):
     """Learn the LSML model.
 
     Parameters
@@ -59,9 +63,8 @@ def fit(self, X, constraints, weights=None, prior=None, verbose=False):
       scale factor for each constraint
     prior : (d x d) matrix, optional
       guess at a metric [default: covariance(X)]
-    verbose : bool, optional
-      if True, prints information while learning
     """
+    verbose = self.params['verbose']
     self._prepare_inputs(X, constraints, weights, prior)
     prior_inv = scipy.linalg.inv(self.M)
     s_best = self._total_loss(self.M, prior_inv)
@@ -93,7 +96,8 @@ def fit(self, X, constraints, weights=None, prior=None, verbose=False):
         break
       self.M = M_best
     else:
-      print("Didn't converge after", it, "iterations. Final loss:", s_best)
+      if verbose:
+        print("Didn't converge after", it, "iterations. Final loss:", s_best)
     return self
 
   def _comparison_loss(self, metric):
@@ -119,18 +123,47 @@ def _gradient(self, metric, prior_inv):
                 (1-np.sqrt(dab/dcd))*np.outer(vcd, vcd))
     return dMetric
 
-  @classmethod
-  def prepare_constraints(cls, labels, num_constraints):
-    C = np.empty((num_constraints,4), dtype=int)
-    a, c = np.random.randint(len(labels), size=(2,num_constraints))
-    for i,(al,cl) in enumerate(zip(labels[a],labels[c])):
-      C[i,1] = choice(np.nonzero(labels == al)[0])
-      C[i,3] = choice(np.nonzero(labels != cl)[0])
-    C[:,0] = a
-    C[:,2] = c
-    return C
-
-
 def _regularization_loss(metric, prior_inv):
   sign, logdet = np.linalg.slogdet(metric)
   return np.sum(metric * prior_inv) - sign * logdet
+
+
+class LSML_Supervised(LSML):
+  def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_constraints=None, weights=None, verbose=False):
+    """Initialize the learner.
+
+    Parameters
+    ----------
+    tol : float, optional
+    max_iter : int, optional
+    prior : (d x d) matrix, optional
+      guess at a metric [default: covariance(X)]
+    num_constraints : int, optional (default: 20 * num_classes**2)
+    weights : (m,) array of floats, optional
+      scale factor for each constraint
+    verbose : bool, optional
+      if True, prints information while learning
+    """
+    LSML.__init__(self, tol=tol, max_iter=max_iter, verbose=verbose)
+    self.params.update({
+      'prior': prior,
+      'num_constraints': num_constraints,
+      'weights': weights,
+    })
+
+  def fit(self, X, labels):
+    """Create constraints from labels and learn the LSML model.
+    Defaults to 20 * num_classes**2 constraints if num_constraints was not set.
+
+    Parameters
+    ----------
+    X : (n x d) data matrix
+      each row corresponds to a single instance
+    labels : (n) data labels
+    """
+    num_constraints = self.params['num_constraints']
+    if num_constraints is None:
+      num_classes = len(np.unique(labels))
+      num_constraints = 20 * num_classes**2
+
+    C = relativeQuadruplets(labels, num_constraints)
+    return LSML.fit(self, X, C, weights=self.params['weights'], prior=self.params['prior'])
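Same pattern in both files: `verbose` moves from `fit()` into the constructor, keeping `fit()` signatures uniform across learners. Existing callers migrate like this (sketch):

```python
from sklearn.datasets import load_iris
from metric_learn import LMNN

iris = load_iris()
# Before this patch: LMNN(k=3, learn_rate=1e-6).fit(X, labels, verbose=True)
lmnn = LMNN(k=3, learn_rate=1e-6, verbose=True)  # verbose set at construction
lmnn.fit(iris.data, iris.target)
```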
diff --git a/metric_learn/rca.py b/metric_learn/rca.py
index 818144fb..d76ef21a 100644
--- a/metric_learn/rca.py
+++ b/metric_learn/rca.py
@@ -15,6 +15,7 @@ import random
 from six.moves import xrange
 
 from .base_metric import BaseMetricLearner
+from .constraints import chunks
 
 
 class RCA(BaseMetricLearner):
@@ -88,30 +89,44 @@ def fit(self, data, chunks):
     else:
       self._transformer = _inv_sqrtm(inner_cov).T
 
-  @classmethod
-  def prepare_constraints(cls, Y, num_chunks=100, chunk_size=2, seed=None):
-    random.seed(seed)
-    chunks = -np.ones_like(Y, dtype=int)
-    uniq, lookup = np.unique(Y, return_inverse=True)
-    all_inds = [set(np.where(lookup==c)[0]) for c in xrange(len(uniq))]
-    idx = 0
-    while idx < num_chunks and all_inds:
-      c = random.randint(0, len(all_inds)-1)
-      inds = all_inds[c]
-      if len(inds) < chunk_size:
-        del all_inds[c]
-        continue
-      ii = random.sample(inds, chunk_size)
-      inds.difference_update(ii)
-      chunks[ii] = idx
-      idx += 1
-    if idx < num_chunks:
-      raise ValueError('Unable to make %d chunks of %d examples each' %
-                       (num_chunks, chunk_size))
-    return chunks
-
+    return self
 
 def _inv_sqrtm(x):
   '''Computes x^(-1/2)'''
   vals, vecs = np.linalg.eigh(x)
   return (vecs / np.sqrt(vals)).dot(vecs.T)
+
+
+class RCA_Supervised(RCA):
+  """Supervised version of Relevant Components Analysis (RCA)"""
+  def __init__(self, dim=None, num_chunks=None, chunk_size=None, seed=None):
+    """Initialize the learner.
+
+    Parameters
+    ----------
+    dim : int, optional
+      embedding dimension (default: original dimension of data)
+    num_chunks : int, optional (default: 100)
+    chunk_size : int, optional (default: 2)
+    seed : int, optional
+    """
+    # @TODO: remove seed from params. See @TODO in constraints.chunks
+    RCA.__init__(self, dim=dim)
+    self.params.update({
+      'num_chunks': 100 if num_chunks is None else num_chunks,
+      'chunk_size': 2 if chunk_size is None else chunk_size,
+      'seed': seed,
+    })
+
+  def fit(self, X, labels):
+    """Create chunks from labels and learn the RCA model.
+    Uses num_chunks and chunk_size from the constructor.
+
+    Parameters
+    ----------
+    X : (n x d) data matrix
+      each row corresponds to a single instance
+    labels : (n) data labels
+    """
+    C = chunks(labels, self.params['num_chunks'], self.params['chunk_size'],
+               self.params['seed'])
+    return RCA.fit(self, X, C)
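Unlike the pair- and quadruplet-based learners, RCA consumes chunk assignments, so `RCA_Supervised` routes labels through `constraints.chunks`. A sketch with synthetic data (not part of the patch):

```python
import numpy as np
from metric_learn import RCA_Supervised
from metric_learn.constraints import chunks

labels = np.repeat(np.arange(3), 10)  # 3 classes, 10 points each
# chunk_ids[i] is the chunk index of point i, or -1 if unassigned;
# each chunk groups chunk_size points drawn from a single class.
chunk_ids = chunks(labels, num_chunks=10, chunk_size=2, seed=1234)

X = np.random.randn(30, 4)
rca = RCA_Supervised(dim=2, num_chunks=10, chunk_size=2, seed=1234)
rca.fit(X, labels)  # builds the chunks internally, then calls RCA.fit
```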
diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py
index 5794bcfc..c99f8214 100644
--- a/metric_learn/sdml.py
+++ b/metric_learn/sdml.py
@@ -15,19 +15,23 @@ from sklearn.covariance import graph_lasso
 from sklearn.utils.extmath import pinvh
 
 from .base_metric import BaseMetricLearner
+from .constraints import adjacencyMatrix
 
 
 class SDML(BaseMetricLearner):
-  def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True):
+  def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, verbose=False):
     '''
     balance_param: trade off between sparsity and M0 prior
    sparsity_param: trade off between optimizer and sparseness (see graph_lasso)
     use_cov: controls prior matrix, will use the identity if use_cov=False
+    verbose : bool, optional
+      if True, prints information while learning
     '''
     self.params = {
       'balance_param': balance_param,
       'sparsity_param': sparsity_param,
       'use_cov': use_cov,
+      'verbose': verbose,
     }
 
   def _prepare_inputs(self, X, W):
@@ -43,7 +47,7 @@ def _prepare_inputs(self, X, W):
   def metric(self):
     return self.M
 
-  def fit(self, X, W, verbose=False):
+  def fit(self, X, W):
     """
     X: data matrix, (n x d)
     W: connectivity graph, (n x n). +1 for positive pairs, -1 for negative.
@@ -54,20 +58,36 @@ def fit(self, X, W, verbose=False):
     # hack: ensure positive semidefinite
     emp_cov = emp_cov.T.dot(emp_cov)
     self.M, _ = graph_lasso(emp_cov, self.params['sparsity_param'],
-                            verbose=verbose)
+                            verbose=self.params['verbose'])
     return self
 
-  @classmethod
-  def prepare_constraints(self, labels, num_points, num_constraints):
-    a, c = np.random.randint(len(labels), size=(2,num_constraints))
-    b, d = np.empty((2, num_constraints), dtype=int)
-    for i,(al,cl) in enumerate(zip(labels[a],labels[c])):
-      b[i] = choice(np.nonzero(labels == al)[0])
-      d[i] = choice(np.nonzero(labels != cl)[0])
-    W = np.zeros((num_points,num_points))
-    W[a,b] = 1
-    W[c,d] = -1
-    # make W symmetric
-    W[b,a] = 1
-    W[d,c] = -1
-    return W
+
+class SDML_Supervised(SDML):
+  def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, num_constraints=None, verbose=False):
+    '''
+    balance_param: trade off between sparsity and M0 prior
+    sparsity_param: trade off between optimizer and sparseness (see graph_lasso)
+    use_cov: controls prior matrix, will use the identity if use_cov=False
+    num_constraints: int, optional (default: 20 * num_classes**2)
+    verbose : bool, optional
+      if True, prints information while learning
+    '''
+    SDML.__init__(self, balance_param=balance_param, sparsity_param=sparsity_param, use_cov=use_cov, verbose=verbose)
+    self.params['num_constraints'] = num_constraints
+
+  def fit(self, X, labels):
+    """Create constraints from labels and learn the SDML model.
+    Defaults to 20 * num_classes**2 constraints if num_constraints was not set.
+
+    Parameters
+    ----------
+    X : (n x d) data matrix
+      each row corresponds to a single instance
+    labels : (n) data labels
+    """
+    num_constraints = self.params['num_constraints']
+    if num_constraints is None:
+      num_classes = len(np.unique(labels))
+      num_constraints = 20 * num_classes**2
+
+    W = adjacencyMatrix(labels, X.shape[0], num_constraints)
+    return SDML.fit(self, X, W)
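One caveat worth flagging in review: `adjacencyMatrix` materializes a dense (n x n) weight matrix, so `SDML_Supervised` scales quadratically in memory with the number of points (and the old sparse-graph test path is dropped below). A toy sketch of the W it builds:

```python
import numpy as np
from metric_learn.constraints import adjacencyMatrix

labels = np.array([0, 0, 1, 1])
W = adjacencyMatrix(labels, num_points=len(labels), num_constraints=2)
# W is symmetric: sampled same-class pairs get W[i, j] == 1,
# sampled different-class pairs get W[i, j] == -1, zeros elsewhere.
```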
diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py
index d431db8d..5cc1e56c 100644
--- a/test/metric_learn_test.py
+++ b/test/metric_learn_test.py
@@ -6,7 +6,8 @@ from sklearn.datasets import load_iris
 from numpy.testing import assert_array_almost_equal
 
-from metric_learn import LSML, ITML, LMNN, SDML, NCA, LFDA, RCA
+from metric_learn import LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised
+from metric_learn import LMNN, NCA, LFDA
 # Import this specially for testing.
 from metric_learn.lmnn import python_LMNN
 
@@ -35,8 +36,7 @@ class TestLSML(MetricTestCase):
   def test_iris(self):
     num_constraints = 200
 
-    C = LSML.prepare_constraints(self.iris_labels, num_constraints)
-    lsml = LSML().fit(self.iris_points, C, verbose=False)
+    lsml = LSML_Supervised(num_constraints=num_constraints).fit(self.iris_points, self.iris_labels)
 
     csep = class_separation(lsml.transform(), self.iris_labels)
     self.assertLess(csep, 0.8)  # it's pretty terrible
 
@@ -46,9 +46,7 @@ class TestITML(MetricTestCase):
   def test_iris(self):
     num_constraints = 200
 
-    n = self.iris_points.shape[0]
-    C = ITML.prepare_constraints(self.iris_labels, n, num_constraints)
-    itml = ITML().fit(self.iris_points, C, verbose=False)
+    itml = ITML_Supervised(num_constraints=num_constraints).fit(self.iris_points, self.iris_labels)
 
     csep = class_separation(itml.transform(), self.iris_labels)
     self.assertLess(csep, 0.4)  # it's not great
 
@@ -60,8 +58,8 @@ def test_iris(self):
     # Test both impls, if available.
     for LMNN_cls in set((LMNN, python_LMNN)):
-      lmnn = LMNN_cls(k=k, learn_rate=1e-6)
-      lmnn.fit(self.iris_points, self.iris_labels, verbose=False)
+      lmnn = LMNN_cls(k=k, learn_rate=1e-6, verbose=False)
+      lmnn.fit(self.iris_points, self.iris_labels)
 
       csep = class_separation(lmnn.transform(), self.iris_labels)
       self.assertLess(csep, 0.25)
 
@@ -71,17 +69,13 @@ class TestSDML(MetricTestCase):
   def test_iris(self):
     num_constraints = 1500
 
-    n = self.iris_points.shape[0]
-
     # Note: this is a flaky test, which fails for certain seeds.
     # TODO: un-flake it!
     np.random.seed(5555)
 
-    W = SDML.prepare_constraints(self.iris_labels, n, num_constraints)
-
-    # Test sparse graph inputs.
-    for graph in ((W, scipy.sparse.csr_matrix(W))):
-      sdml = SDML().fit(self.iris_points, graph)
-      csep = class_separation(sdml.transform(), self.iris_labels)
-      self.assertLess(csep, 0.25)
+    sdml = SDML_Supervised(num_constraints=num_constraints).fit(self.iris_points, self.iris_labels)
+    csep = class_separation(sdml.transform(), self.iris_labels)
+    self.assertLess(csep, 0.25)
 
 
 class TestNCA(MetricTestCase):
@@ -109,10 +103,8 @@ def test_iris(self):
 
 class TestRCA(MetricTestCase):
   def test_iris(self):
-    rca = RCA(dim=2)
-    chunks = RCA.prepare_constraints(self.iris_labels, num_chunks=30,
-                                     chunk_size=2, seed=1234)
-    rca.fit(self.iris_points, chunks)
+    rca = RCA_Supervised(dim=2, num_chunks=30, chunk_size=2, seed=1234)
+    rca.fit(self.iris_points, self.iris_labels)
 
     csep = class_separation(rca.transform(), self.iris_labels)
     self.assertLess(csep, 0.25)