From 9f5c9987726d8482943b074ee753dcdfba5892f9 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Mon, 14 May 2018 15:55:12 +0200 Subject: [PATCH 1/6] Update API to be compatible with scikit-learn by taking 3D inputs for Weakly Supervised Algorithms. --- metric_learn/constraints.py | 10 ++++ metric_learn/itml.py | 44 ++++++++------ metric_learn/lsml.py | 23 ++++---- metric_learn/mmc.py | 115 +++++++++++++++++++----------------- metric_learn/sdml.py | 28 +++++---- test/metric_learn_test.py | 15 ++--- test/test_fit_transform.py | 8 +-- 7 files changed, 139 insertions(+), 104 deletions(-) diff --git a/metric_learn/constraints.py b/metric_learn/constraints.py index 8824450a..3986fce8 100644 --- a/metric_learn/constraints.py +++ b/metric_learn/constraints.py @@ -100,3 +100,13 @@ def random_subset(all_labels, num_preserved=np.inf, random_state=np.random): partial_labels = np.array(all_labels, copy=True) partial_labels[idx] = -1 return Constraints(partial_labels) + +def wrap_pairs(X, constraints): + a = np.array(constraints[0]) + b = np.array(constraints[1]) + c = np.array(constraints[2]) + d = np.array(constraints[3]) + constraints = np.vstack((np.column_stack((a, b)), np.column_stack((c, d)))) + y = np.vstack([np.ones((len(a), 1)), np.zeros((len(c), 1))]) + pairs = X[constraints] + return pairs, y \ No newline at end of file diff --git a/metric_learn/itml.py b/metric_learn/itml.py index 4d27c412..3992bfbb 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -20,7 +20,7 @@ from sklearn.utils.validation import check_array, check_X_y from .base_metric import BaseMetricLearner -from .constraints import Constraints +from .constraints import Constraints, wrap_pairs from ._util import vector_norm @@ -51,16 +51,20 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, self.A0 = A0 self.verbose = verbose - def _process_inputs(self, X, constraints, bounds): - self.X_ = X = check_array(X) + def _process_pairs(self, pairs, y, bounds): + pairs, y 
= check_X_y(pairs, y, accept_sparse=False, + ensure_2d=False, allow_nd=True) + y = y.astype(bool) + # check to make sure that no two constrained vectors are identical - a,b,c,d = constraints - no_ident = vector_norm(X[a] - X[b]) > 1e-9 - a, b = a[no_ident], b[no_ident] - no_ident = vector_norm(X[c] - X[d]) > 1e-9 - c, d = c[no_ident], d[no_ident] + pos_pairs, neg_pairs = pairs[y], pairs[~y] + pos_no_ident = vector_norm(pos_pairs[:, 0, :] - pos_pairs[:, 1, :]) > 1e-9 + pos_pairs = pos_pairs[pos_no_ident] + neg_no_ident = vector_norm(neg_pairs[:, 0, :] - neg_pairs[:, 1, :]) > 1e-9 + neg_pairs = neg_pairs[neg_no_ident] # init bounds if bounds is None: + X = np.unique(pairs.reshape(-1, pairs.shape[2]), axis=0) self.bounds_ = np.percentile(pairwise_distances(X), (5, 95)) else: assert len(bounds) == 2 @@ -68,12 +72,16 @@ def _process_inputs(self, X, constraints, bounds): self.bounds_[self.bounds_==0] = 1e-9 # init metric if self.A0 is None: - self.A_ = np.identity(X.shape[1]) + self.A_ = np.identity(pairs.shape[2]) else: self.A_ = check_array(self.A0) - return a,b,c,d + pairs = np.vstack([pos_pairs, neg_pairs]) + y = np.hstack([np.ones(len(pos_pairs)), np.zeros(len(neg_pairs))]) + y = y.astype(bool) + return pairs, y + - def fit(self, X, constraints, bounds=None): + def fit(self, pairs, y, bounds=None): """Learn the ITML model. Parameters @@ -86,17 +94,18 @@ def fit(self, X, constraints, bounds=None): bounds : list (pos,neg) pairs, optional bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg """ - a,b,c,d = self._process_inputs(X, constraints, bounds) + pairs, y = self._process_pairs(pairs, y, bounds) gamma = self.gamma - num_pos = len(a) - num_neg = len(c) + pos_pairs, neg_pairs = pairs[y], pairs[~y] + num_pos = len(pos_pairs) + num_neg = len(neg_pairs) _lambda = np.zeros(num_pos + num_neg) lambdaold = np.zeros_like(_lambda) gamma_proj = 1. if gamma is np.inf else gamma/(gamma+1.) 
pos_bhat = np.zeros(num_pos) + self.bounds_[0] neg_bhat = np.zeros(num_neg) + self.bounds_[1] - pos_vv = self.X_[a] - self.X_[b] - neg_vv = self.X_[c] - self.X_[d] + pos_vv = pos_pairs[:, 0, :] - pos_pairs[:, 1, :] + neg_vv = neg_pairs[:, 0, :] - neg_pairs[:, 1, :] A = self.A_ for it in xrange(self.max_iter): @@ -195,4 +204,5 @@ def fit(self, X, y, random_state=np.random): random_state=random_state) pos_neg = c.positive_negative_pairs(num_constraints, random_state=random_state) - return ITML.fit(self, X, pos_neg, bounds=self.bounds) + pairs, y = wrap_pairs(X, pos_neg) + return ITML.fit(self, pairs, y, bounds=self.bounds) diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py index 404fe286..51f4ef48 100644 --- a/metric_learn/lsml.py +++ b/metric_learn/lsml.py @@ -14,7 +14,7 @@ from sklearn.utils.validation import check_array, check_X_y from .base_metric import BaseMetricLearner -from .constraints import Constraints +from .constraints import Constraints, wrap_pairs class LSML(BaseMetricLearner): @@ -35,11 +35,13 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False): self.max_iter = max_iter self.verbose = verbose - def _prepare_inputs(self, X, constraints, weights): - self.X_ = X = check_array(X) - a,b,c,d = constraints - self.vab_ = X[a] - X[b] - self.vcd_ = X[c] - X[d] + def _prepare_quadruplets(self, quadruplets, weights): + pairs = check_array(quadruplets, accept_sparse=False, + ensure_2d=False, allow_nd=True) + + # check to make sure that no two constrained vectors are identical + self.vab_ = quadruplets[:, 0, :] - quadruplets[:, 1, :] + self.vcd_ = quadruplets[:, 2, :] - quadruplets[:, 3, :] if self.vab_.shape != self.vcd_.shape: raise ValueError('Constraints must have same length') if weights is None: @@ -48,6 +50,7 @@ def _prepare_inputs(self, X, constraints, weights): self.w_ = weights self.w_ /= self.w_.sum() # weights must sum to 1 if self.prior is None: + X = np.unique(pairs.reshape(-1, pairs.shape[2]), axis=0) self.prior_inv_ = 
np.atleast_2d(np.cov(X, rowvar=False)) self.M_ = np.linalg.inv(self.prior_inv_) else: @@ -57,7 +60,7 @@ def _prepare_inputs(self, X, constraints, weights): def metric(self): return self.M_ - def fit(self, X, constraints, weights=None): + def fit(self, quadruplets, weights=None): """Learn the LSML model. Parameters @@ -69,7 +72,7 @@ def fit(self, X, constraints, weights=None): weights : (m,) array of floats, optional scale factor for each constraint """ - self._prepare_inputs(X, constraints, weights) + self._prepare_quadruplets(quadruplets, weights) step_sizes = np.logspace(-10, 0, 10) # Keep track of the best step size and the loss at that step. l_best = 0 @@ -179,6 +182,6 @@ def fit(self, X, y, random_state=np.random): c = Constraints.random_subset(y, self.num_labeled, random_state=random_state) - pairs = c.positive_negative_pairs(num_constraints, same_length=True, + pos_neg = c.positive_negative_pairs(num_constraints, same_length=True, random_state=random_state) - return LSML.fit(self, X, pairs, weights=self.weights) + return LSML.fit(self, X[np.column_stack(pos_neg)], weights=self.weights) diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index efe33c38..39435dac 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -23,7 +23,7 @@ from sklearn.utils.validation import check_array, check_X_y from .base_metric import BaseMetricLearner -from .constraints import Constraints +from .constraints import Constraints, wrap_pairs from ._util import vector_norm @@ -58,7 +58,8 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-3, self.diagonal_c = diagonal_c self.verbose = verbose - def fit(self, X, constraints): + + def fit(self, pairs, y): """Learn the MMC model. 
Parameters @@ -69,30 +70,31 @@ def fit(self, X, constraints): (a,b,c,d) indices into X, with (a,b) specifying similar and (c,d) dissimilar pairs """ - constraints = self._process_inputs(X, constraints) + pairs, y = self._process_pairs(pairs, y) if self.diagonal: - return self._fit_diag(X, constraints) + return self._fit_diag(pairs, y) else: - return self._fit_full(X, constraints) - - def _process_inputs(self, X, constraints): + return self._fit_full(pairs, y) - self.X_ = X = check_array(X) + def _process_pairs(self, pairs, y): + pairs, y = check_X_y(pairs, y, accept_sparse=False, + ensure_2d=False, allow_nd=True) + y = y.astype(bool) # check to make sure that no two constrained vectors are identical - a,b,c,d = constraints - no_ident = vector_norm(X[a] - X[b]) > 1e-9 - a, b = a[no_ident], b[no_ident] - no_ident = vector_norm(X[c] - X[d]) > 1e-9 - c, d = c[no_ident], d[no_ident] - if len(a) == 0: + pos_pairs, neg_pairs = pairs[y], pairs[~y] + pos_no_ident = vector_norm(pos_pairs[:, 0, :] - pos_pairs[:, 1, :]) > 1e-9 + pos_pairs = pos_pairs[pos_no_ident] + neg_no_ident = vector_norm(neg_pairs[:, 0, :] - neg_pairs[:, 1, :]) > 1e-9 + neg_pairs = neg_pairs[neg_no_ident] + if len(pos_pairs) == 0: raise ValueError('No non-trivial similarity constraints given for MMC.') - if len(c) == 0: + if len(neg_pairs) == 0: raise ValueError('No non-trivial dissimilarity constraints given for MMC.') # init metric if self.A0 is None: - self.A_ = np.identity(X.shape[1]) + self.A_ = np.identity(pairs.shape[2]) if not self.diagonal: # Don't know why division by 10... it's in the original code # and seems to affect the overall scale of the learned metric. 
@@ -100,9 +102,12 @@ def _process_inputs(self, X, constraints): else: self.A_ = check_array(self.A0) - return a,b,c,d + pairs = np.vstack([pos_pairs, neg_pairs]) + y = np.hstack([np.ones(len(pos_pairs)), np.zeros(len(neg_pairs))]) + y = y.astype(bool) + return pairs, y - def _fit_full(self, X, constraints): + def _fit_full(self, pairs, y): """Learn full metric using MMC. Parameters @@ -113,17 +118,16 @@ def _fit_full(self, X, constraints): (a,b,c,d) indices into X, with (a,b) specifying similar and (c,d) dissimilar pairs """ - a,b,c,d = constraints - num_pos = len(a) - num_neg = len(c) - num_samples, num_dim = X.shape + num_dim = pairs.shape[2] error1 = error2 = 1e10 eps = 0.01 # error-bound of iterative projection on C1 and C2 A = self.A_ + pos_pairs, neg_pairs = pairs[y], pairs[~y] + # Create weight vector from similar samples - pos_diff = X[a] - X[b] + pos_diff = pos_pairs[:, 0, :] - pos_pairs[:, 1, :] w = np.einsum('ij,ik->jk', pos_diff, pos_diff).ravel() # `w` is the sum of all outer products of the rows in `pos_diff`. 
# The above `einsum` is equivalent to the much more inefficient: @@ -140,9 +144,10 @@ def _fit_full(self, X, constraints): cycle = 1 alpha = 0.1 # initial step size along gradient - - grad1 = self._fS1(X, a, b, A) # gradient of similarity constraint function - grad2 = self._fD1(X, c, d, A) # gradient of dissimilarity constraint function + grad1 = self._fS1(pos_pairs, A) # gradient of similarity + # constraint function + grad2 = self._fD1(neg_pairs, A) # gradient of dissimilarity + # constraint function M = self._grad_projection(grad1, grad2) # gradient of fD1 orthogonal to fS1 A_old = A.copy() @@ -183,8 +188,8 @@ def _fit_full(self, X, constraints): # max: g(A) >= 1 # here we suppose g(A) = fD(A) = \sum_{I,J \in D} sqrt(d_ij' A d_ij) - obj_previous = self._fD(X, c, d, A_old) # g(A_old) - obj = self._fD(X, c, d, A) # g(A) + obj_previous = self._fD(neg_pairs, A_old) # g(A_old) + obj = self._fD(neg_pairs, A) # g(A) if satisfy and (obj > obj_previous or cycle == 0): @@ -193,8 +198,8 @@ def _fit_full(self, X, constraints): # and update from the current A. alpha *= 1.05 A_old[:] = A - grad2 = self._fS1(X, a, b, A) - grad1 = self._fD1(X, c, d, A) + grad2 = self._fS1(pos_pairs, A) + grad1 = self._fD1(neg_pairs, A) M = self._grad_projection(grad1, grad2) A += alpha * M @@ -224,7 +229,7 @@ def _fit_full(self, X, constraints): self.n_iter_ = cycle return self - def _fit_diag(self, X, constraints): + def _fit_diag(self, pairs, y): """Learn diagonal metric using MMC. 
Parameters ---------- @@ -234,12 +239,9 @@ def _fit_diag(self, X, constraints): (a,b,c,d) indices into X, with (a,b) specifying similar and (c,d) dissimilar pairs """ - a,b,c,d = constraints - num_pos = len(a) - num_neg = len(c) - num_samples, num_dim = X.shape - - s_sum = np.sum((X[a] - X[b]) ** 2, axis=0) + num_dim = pairs.shape[2] + pos_pairs, neg_pairs = pairs[y], pairs[~y] + s_sum = np.sum((pos_pairs[:, 0, :] - pos_pairs[:, 1, :]) ** 2, axis=0) it = 0 error = 1.0 @@ -249,20 +251,21 @@ def _fit_diag(self, X, constraints): while error > self.convergence_threshold and it < self.max_iter: - fD0, fD_1st_d, fD_2nd_d = self._D_constraint(X, c, d, w) + fD0, fD_1st_d, fD_2nd_d = self._D_constraint(neg_pairs, w) obj_initial = np.dot(s_sum, w) + self.diagonal_c * fD0 fS_1st_d = s_sum # first derivative of the similarity constraints gradient = fS_1st_d - self.diagonal_c * fD_1st_d # gradient of the objective hessian = -self.diagonal_c * fD_2nd_d + eps * np.eye(num_dim) # Hessian of the objective - step = np.dot(np.linalg.inv(hessian), gradient); + step = np.dot(np.linalg.inv(hessian), gradient) # Newton-Rapshon update # search over optimal lambda lambd = 1 # initial step-size w_tmp = np.maximum(0, w - lambd * step) - obj = np.dot(s_sum, w_tmp) + self.diagonal_c * self._D_objective(X, c, d, w_tmp) + obj = np.dot(s_sum, w_tmp) + self.diagonal_c * \ + self._D_objective(neg_pairs, w_tmp) obj_previous = obj * 1.1 # just to get the while-loop started inner_it = 0 @@ -271,7 +274,8 @@ def _fit_diag(self, X, constraints): w_previous = w_tmp.copy() lambd /= reduction w_tmp = np.maximum(0, w - lambd * step) - obj = np.dot(s_sum, w_tmp) + self.diagonal_c * self._D_objective(X, c, d, w_tmp) + obj = np.dot(s_sum, w_tmp) + self.diagonal_c * \ + self._D_objective(neg_pairs, w_tmp) inner_it += 1 w[:] = w_previous @@ -283,16 +287,16 @@ def _fit_diag(self, X, constraints): self.A_ = np.diag(w) return self - def _fD(self, X, c, d, A): + def _fD(self, neg_pairs, A): """The value of the 
dissimilarity constraint function. f = f(\sum_{ij \in D} distance(x_i, x_j)) i.e. distance can be L1: \sqrt{(x_i-x_j)A(x_i-x_j)'} """ - diff = X[c] - X[d] + diff = neg_pairs[:, 0, :] - neg_pairs[:, 1, :] return np.log(np.sum(np.sqrt(np.sum(np.dot(diff, A) * diff, axis=1))) + 1e-6) - def _fD1(self, X, c, d, A): + def _fD1(self, neg_pairs, A): """The gradient of the dissimilarity constraint function w.r.t. A. For example, let distance by L1 norm: @@ -304,8 +308,8 @@ def _fD1(self, X, c, d, A): df/dA = f'(\sum_{ij \in D} \sqrt{tr(d_ij'*d_ij*A)}) * 0.5*(\sum_{ij \in D} (1/sqrt{tr(d_ij'*d_ij*A)})*(d_ij'*d_ij)) """ - dim = X.shape[1] - diff = X[c] - X[d] + dim = neg_pairs.shape[2] + diff = neg_pairs[:, 0, :] - neg_pairs[:, 1, :] # outer products of all rows in `diff` M = np.einsum('ij,ik->ijk', diff, diff) # faster version of: dist = np.sqrt(np.sum(M * A[None,:,:], axis=(1,2))) @@ -315,7 +319,7 @@ def _fD1(self, X, c, d, A): sum_dist = dist.sum() return sum_deri / (sum_dist + 1e-6) - def _fS1(self, X, a, b, A): + def _fS1(self, pos_pairs, A): """The gradient of the similarity constraint function w.r.t. A. 
f = \sum_{ij}(x_i-x_j)A(x_i-x_j)' = \sum_{ij}d_ij*A*d_ij' @@ -324,8 +328,8 @@ def _fS1(self, X, a, b, A): Note that d_ij*A*d_ij' = tr(d_ij*A*d_ij') = tr(d_ij'*d_ij*A) so, d(d_ij*A*d_ij')/dA = d_ij'*d_ij """ - dim = X.shape[1] - diff = X[a] - X[b] + dim = pos_pairs.shape[2] + diff = pos_pairs[:, 0, :] - pos_pairs[:, 1, :] return np.einsum('ij,ik->jk', diff, diff) # sum of outer products of all rows in `diff` def _grad_projection(self, grad1, grad2): @@ -334,15 +338,17 @@ def _grad_projection(self, grad1, grad2): gtemp /= np.linalg.norm(gtemp) return gtemp - def _D_objective(self, X, c, d, w): - return np.log(np.sum(np.sqrt(np.sum(((X[c] - X[d]) ** 2) * w[None,:], axis=1) + 1e-6))) + def _D_objective(self, neg_pairs, w): + return np.log(np.sum(np.sqrt(np.sum(((neg_pairs[:, 0, :] - + neg_pairs[:, 1, :]) ** 2) * + w[None,:], axis=1) + 1e-6))) - def _D_constraint(self, X, c, d, w): + def _D_constraint(self, neg_pairs, w): """Compute the value, 1st derivative, second derivative (Hessian) of a dissimilarity constraint function gF(sum_ij distance(d_ij A d_ij)) where A is a diagonal matrix (in the form of a column vector 'w'). 
""" - diff = X[c] - X[d] + diff = neg_pairs[:, 0, :] - neg_pairs[:, 1, :] diff_sq = diff * diff dist = np.sqrt(diff_sq.dot(w)) sum_deri1 = np.einsum('ij,i', diff_sq, 0.5 / np.maximum(dist, 1e-6)) @@ -437,4 +443,5 @@ def fit(self, X, y, random_state=np.random): random_state=random_state) pos_neg = c.positive_negative_pairs(num_constraints, random_state=random_state) - return MMC.fit(self, X, pos_neg) + pairs, y = wrap_pairs(X, pos_neg) + return MMC.fit(self, pairs, y) diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index 93280334..dd8c95da 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -13,10 +13,10 @@ from scipy.sparse.csgraph import laplacian from sklearn.covariance import graph_lasso from sklearn.utils.extmath import pinvh -from sklearn.utils.validation import check_array +from sklearn.utils.validation import check_array, check_X_y from .base_metric import BaseMetricLearner -from .constraints import Constraints +from .constraints import Constraints, wrap_pairs class SDML(BaseMetricLearner): @@ -42,21 +42,22 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True, self.use_cov = use_cov self.verbose = verbose - def _prepare_inputs(self, X, W): - self.X_ = X = check_array(X) - W = check_array(W, accept_sparse=True) + def _prepare_pairs(self, pairs, y): + pairs, y = check_X_y(pairs, y, accept_sparse=False, + ensure_2d=False, allow_nd=True) # set up prior M if self.use_cov: + X = np.unique(pairs.reshape(-1, pairs.shape[2]), axis=0) self.M_ = pinvh(np.cov(X, rowvar = False)) else: - self.M_ = np.identity(X.shape[1]) - L = laplacian(W, normed=False) - return X.T.dot(L.dot(X)) + self.M_ = np.identity(pairs.shape[2]) + diff = pairs[:, 0] - pairs[:, 1] + return (diff.T * y).dot(diff) def metric(self): return self.M_ - def fit(self, X, W): + def fit(self, pairs, y): """Learn the SDML model. Parameters @@ -71,7 +72,7 @@ def fit(self, X, W): self : object Returns the instance. 
""" - loss_matrix = self._prepare_inputs(X, W) + loss_matrix = self._prepare_pairs(pairs, y) P = self.M_ + self.balance_param * loss_matrix emp_cov = pinvh(P) # hack: ensure positive semidefinite @@ -131,5 +132,8 @@ def fit(self, X, y, random_state=np.random): c = Constraints.random_subset(y, self.num_labeled, random_state=random_state) - adj = c.adjacency_matrix(num_constraints, random_state=random_state) - return SDML.fit(self, X, adj) + pos_neg = c.positive_negative_pairs(num_constraints, + random_state=random_state) + pairs, y = wrap_pairs(X, pos_neg) + y = 2 * y - 1 + return SDML.fit(self, pairs, y) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 6d78c657..1756b105 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -9,6 +9,7 @@ LMNN, NCA, LFDA, Covariance, MLKR, MMC, LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised) # Import this specially for testing. +from metric_learn.constraints import wrap_pairs from metric_learn.lmnn import python_LMNN @@ -47,7 +48,7 @@ def test_iris(self): lsml = LSML_Supervised(num_constraints=200) lsml.fit(self.iris_points, self.iris_labels) - csep = class_separation(lsml.transform(), self.iris_labels) + csep = class_separation(lsml.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.8) # it's pretty terrible @@ -56,7 +57,7 @@ def test_iris(self): itml = ITML_Supervised(num_constraints=200) itml.fit(self.iris_points, self.iris_labels) - csep = class_separation(itml.transform(), self.iris_labels) + csep = class_separation(itml.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.2) @@ -79,7 +80,7 @@ def test_iris(self): sdml = SDML_Supervised(num_constraints=1500) sdml.fit(self.iris_points, self.iris_labels, random_state=rs) - csep = class_separation(sdml.transform(), self.iris_labels) + csep = class_separation(sdml.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.25) @@ -160,7 +161,7 @@ def 
test_iris(self): # Full metric mmc = MMC(convergence_threshold=0.01) - mmc.fit(self.iris_points, [a,b,c,d]) + mmc.fit(*wrap_pairs(self.iris_points, [a,b,c,d])) expected = [[+0.00046504, +0.00083371, -0.00111959, -0.00165265], [+0.00083371, +0.00149466, -0.00200719, -0.00296284], [-0.00111959, -0.00200719, +0.00269546, +0.00397881], @@ -169,20 +170,20 @@ def test_iris(self): # Diagonal metric mmc = MMC(diagonal=True) - mmc.fit(self.iris_points, [a,b,c,d]) + mmc.fit(*wrap_pairs(self.iris_points, [a,b,c,d])) expected = [0, 0, 1.21045968, 1.22552608] assert_array_almost_equal(np.diag(expected), mmc.metric(), decimal=6) # Supervised Full mmc = MMC_Supervised() mmc.fit(self.iris_points, self.iris_labels) - csep = class_separation(mmc.transform(), self.iris_labels) + csep = class_separation(mmc.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.15) # Supervised Diagonal mmc = MMC_Supervised(diagonal=True) mmc.fit(self.iris_points, self.iris_labels) - csep = class_separation(mmc.transform(), self.iris_labels) + csep = class_separation(mmc.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.2) diff --git a/test/test_fit_transform.py b/test/test_fit_transform.py index 707815ec..d239ec95 100644 --- a/test/test_fit_transform.py +++ b/test/test_fit_transform.py @@ -30,7 +30,7 @@ def test_lsml_supervised(self): seed = np.random.RandomState(1234) lsml = LSML_Supervised(num_constraints=200) lsml.fit(self.X, self.y, random_state=seed) - res_1 = lsml.transform() + res_1 = lsml.transform(self.X) seed = np.random.RandomState(1234) lsml = LSML_Supervised(num_constraints=200) @@ -42,7 +42,7 @@ def test_itml_supervised(self): seed = np.random.RandomState(1234) itml = ITML_Supervised(num_constraints=200) itml.fit(self.X, self.y, random_state=seed) - res_1 = itml.transform() + res_1 = itml.transform(self.X) seed = np.random.RandomState(1234) itml = ITML_Supervised(num_constraints=200) @@ -64,7 +64,7 @@ def test_sdml_supervised(self): seed = 
np.random.RandomState(1234) sdml = SDML_Supervised(num_constraints=1500) sdml.fit(self.X, self.y, random_state=seed) - res_1 = sdml.transform() + res_1 = sdml.transform(self.X) seed = np.random.RandomState(1234) sdml = SDML_Supervised(num_constraints=1500) @@ -122,7 +122,7 @@ def test_mmc_supervised(self): seed = np.random.RandomState(1234) mmc = MMC_Supervised(num_constraints=200) mmc.fit(self.X, self.y, random_state=seed) - res_1 = mmc.transform() + res_1 = mmc.transform(self.X) seed = np.random.RandomState(1234) mmc = MMC_Supervised(num_constraints=200) From 4c887d7d6486760d919642b1cd741086dbbbb007 Mon Sep 17 00:00:00 2001 From: William de Vazelhes <31916524+wdevazelhes@users.noreply.github.com> Date: Fri, 18 May 2018 17:48:58 +0200 Subject: [PATCH 2/6] Deals with scipy's new version, where eigsh can call eigh. (#94) --- metric_learn/lfda.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/metric_learn/lfda.py b/metric_learn/lfda.py index dbe5aa4f..809f092b 100644 --- a/metric_learn/lfda.py +++ b/metric_learn/lfda.py @@ -139,10 +139,11 @@ def _sum_outer(x): def _eigh(a, b, dim): try: return scipy.sparse.linalg.eigsh(a, k=dim, M=b, which='LA') - except (ValueError, scipy.sparse.linalg.ArpackNoConvergence): - pass - try: - return scipy.linalg.eigh(a, b) except np.linalg.LinAlgError: - pass + pass # scipy already tried eigh for us + except (ValueError, scipy.sparse.linalg.ArpackNoConvergence): + try: + return scipy.linalg.eigh(a, b) + except np.linalg.LinAlgError: + pass return scipy.linalg.eig(a, b) From a7e480720f2ca31bb88e2dd5cf1ddd5315156773 Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 22 May 2018 09:51:37 +0200 Subject: [PATCH 3/6] find unique rows in a way compatible with numpy 1.12.1 --- metric_learn/itml.py | 2 +- metric_learn/lsml.py | 2 +- metric_learn/sdml.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/metric_learn/itml.py b/metric_learn/itml.py index 3992bfbb..fcfebaee 100644 --- 
a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -64,7 +64,7 @@ def _process_pairs(self, pairs, y, bounds): neg_pairs = neg_pairs[neg_no_ident] # init bounds if bounds is None: - X = np.unique(pairs.reshape(-1, pairs.shape[2]), axis=0) + X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])}) self.bounds_ = np.percentile(pairwise_distances(X), (5, 95)) else: assert len(bounds) == 2 diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py index 51f4ef48..5fc418e0 100644 --- a/metric_learn/lsml.py +++ b/metric_learn/lsml.py @@ -50,7 +50,7 @@ def _prepare_quadruplets(self, quadruplets, weights): self.w_ = weights self.w_ /= self.w_.sum() # weights must sum to 1 if self.prior is None: - X = np.unique(pairs.reshape(-1, pairs.shape[2]), axis=0) + X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])}) self.prior_inv_ = np.atleast_2d(np.cov(X, rowvar=False)) self.M_ = np.linalg.inv(self.prior_inv_) else: diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index dd8c95da..213e2904 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -47,7 +47,7 @@ def _prepare_pairs(self, pairs, y): ensure_2d=False, allow_nd=True) # set up prior M if self.use_cov: - X = np.unique(pairs.reshape(-1, pairs.shape[2]), axis=0) + X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])}) self.M_ = pinvh(np.cov(X, rowvar = False)) else: self.M_ = np.identity(pairs.shape[2]) From 903f17490486ac105baf967f4caa572f4445c1fe Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Tue, 22 May 2018 10:45:44 +0200 Subject: [PATCH 4/6] Update docstring for new api --- metric_learn/itml.py | 14 +++++++++----- metric_learn/lsml.py | 16 +++++++++++----- metric_learn/mmc.py | 14 +++++++++----- metric_learn/sdml.py | 8 ++++---- 4 files changed, 33 insertions(+), 19 deletions(-) diff --git a/metric_learn/itml.py b/metric_learn/itml.py index fcfebaee..3d9aff2a 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -86,13 +86,17 @@ def 
fit(self, pairs, y, bounds=None): Parameters ---------- - X : (n x d) data matrix - each row corresponds to a single instance - constraints : 4-tuple of arrays - (a,b,c,d) indices into X, with (a,b) specifying positive and (c,d) - negative pairs + pairs: array-like, shape=(n_constraints, 2, n_features) + Array of pairs. Each row corresponds to two points. + y: array-like, of shape (n_constraints,) + Labels of constraints. Should be 0 for dissimilar pair, 1 for similar. bounds : list (pos,neg) pairs, optional bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg + + Returns + ------- + self : object + Returns the instance. """ pairs, y = self._process_pairs(pairs, y, bounds) gamma = self.gamma diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py index 5fc418e0..b8b69f19 100644 --- a/metric_learn/lsml.py +++ b/metric_learn/lsml.py @@ -65,12 +65,18 @@ def fit(self, quadruplets, weights=None): Parameters ---------- - X : (n x d) data matrix - each row corresponds to a single instance - constraints : 4-tuple of arrays - (a,b,c,d) indices into X, such that d(X[a],X[b]) < d(X[c],X[d]) - weights : (m,) array of floats, optional + quadruplets : array-like, shape=(n_constraints, 4, n_features) + Each row corresponds to 4 points. In order to supervise the + algorithm in the right way, we should have the four samples ordered + in a way such that: d(quadruplets[i, 0], quadruplets[i, 1]) < d(quadruplets[i, 2], quadruplets[i, 3]) + for all 0 <= i < n_constraints. + weights : (n_constraints,) array of floats, optional scale factor for each constraint + + Returns + ------- + self : object + Returns the instance. 
""" self._prepare_quadruplets(quadruplets, weights) step_sizes = np.logspace(-10, 0, 10) diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index 39435dac..3f95babd 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -64,11 +64,15 @@ def fit(self, pairs, y): Parameters ---------- - X : (n x d) data matrix - each row corresponds to a single instance - constraints : 4-tuple of arrays - (a,b,c,d) indices into X, with (a,b) specifying similar and (c,d) - dissimilar pairs + pairs: array-like, shape=(n_constraints, 2, n_features) + Array of pairs. Each row corresponds to two points. + y: array-like, of shape (n_constraints,) + Labels of constraints. Should be 0 for dissimilar pair, 1 for similar. + + Returns + ------- + self : object + Returns the instance. """ pairs, y = self._process_pairs(pairs, y) if self.diagonal: diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index 213e2904..9378e260 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -62,10 +62,10 @@ def fit(self, pairs, y): Parameters ---------- - X : array-like, shape (n, d) - data matrix, where each row corresponds to a single instance - W : array-like, shape (n, n) - connectivity graph, with +1 for positive pairs and -1 for negative + pairs: array-like, shape=(n_constraints, 2, n_features) + Array of pairs. Each row corresponds to two points. + y: array-like, of shape (n_constraints,) + Labels of constraints. Should be 0 for dissimilar pair, 1 for similar. Returns ------- From 374a8512ff13b8c427c03ed03a36a782fa8ef55c Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Thu, 24 May 2018 15:22:15 +0200 Subject: [PATCH 5/6] Change labels y to be +1/-1 (cf. comment https://github.com/metric-learn/metric-learn/pull/92#discussion_r190559227). 
--- metric_learn/constraints.py | 2 +- metric_learn/itml.py | 8 +++----- metric_learn/mmc.py | 10 ++++------ metric_learn/sdml.py | 1 - 4 files changed, 8 insertions(+), 13 deletions(-) diff --git a/metric_learn/constraints.py b/metric_learn/constraints.py index 3986fce8..17523a46 100644 --- a/metric_learn/constraints.py +++ b/metric_learn/constraints.py @@ -107,6 +107,6 @@ def wrap_pairs(X, constraints): c = np.array(constraints[2]) d = np.array(constraints[3]) constraints = np.vstack((np.column_stack((a, b)), np.column_stack((c, d)))) - y = np.vstack([np.ones((len(a), 1)), np.zeros((len(c), 1))]) + y = np.vstack([np.ones((len(a), 1)), - np.ones((len(c), 1))]) pairs = X[constraints] return pairs, y \ No newline at end of file diff --git a/metric_learn/itml.py b/metric_learn/itml.py index 3d9aff2a..27d1835a 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -54,10 +54,9 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3, def _process_pairs(self, pairs, y, bounds): pairs, y = check_X_y(pairs, y, accept_sparse=False, ensure_2d=False, allow_nd=True) - y = y.astype(bool) # check to make sure that no two constrained vectors are identical - pos_pairs, neg_pairs = pairs[y], pairs[~y] + pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] pos_no_ident = vector_norm(pos_pairs[:, 0, :] - pos_pairs[:, 1, :]) > 1e-9 pos_pairs = pos_pairs[pos_no_ident] neg_no_ident = vector_norm(neg_pairs[:, 0, :] - neg_pairs[:, 1, :]) > 1e-9 @@ -76,8 +75,7 @@ def _process_pairs(self, pairs, y, bounds): else: self.A_ = check_array(self.A0) pairs = np.vstack([pos_pairs, neg_pairs]) - y = np.hstack([np.ones(len(pos_pairs)), np.zeros(len(neg_pairs))]) - y = y.astype(bool) + y = np.hstack([np.ones(len(pos_pairs)), - np.ones(len(neg_pairs))]) return pairs, y @@ -100,7 +98,7 @@ def fit(self, pairs, y, bounds=None): """ pairs, y = self._process_pairs(pairs, y, bounds) gamma = self.gamma - pos_pairs, neg_pairs = pairs[y], pairs[~y] + pos_pairs, neg_pairs = pairs[y == 
1], pairs[y == -1] num_pos = len(pos_pairs) num_neg = len(neg_pairs) _lambda = np.zeros(num_pos + num_neg) diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index 3f95babd..46445103 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -83,10 +83,9 @@ def fit(self, pairs, y): def _process_pairs(self, pairs, y): pairs, y = check_X_y(pairs, y, accept_sparse=False, ensure_2d=False, allow_nd=True) - y = y.astype(bool) # check to make sure that no two constrained vectors are identical - pos_pairs, neg_pairs = pairs[y], pairs[~y] + pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] pos_no_ident = vector_norm(pos_pairs[:, 0, :] - pos_pairs[:, 1, :]) > 1e-9 pos_pairs = pos_pairs[pos_no_ident] neg_no_ident = vector_norm(neg_pairs[:, 0, :] - neg_pairs[:, 1, :]) > 1e-9 @@ -107,8 +106,7 @@ def _process_pairs(self, pairs, y): self.A_ = check_array(self.A0) pairs = np.vstack([pos_pairs, neg_pairs]) - y = np.hstack([np.ones(len(pos_pairs)), np.zeros(len(neg_pairs))]) - y = y.astype(bool) + y = np.hstack([np.ones(len(pos_pairs)), - np.ones(len(neg_pairs))]) return pairs, y def _fit_full(self, pairs, y): @@ -128,7 +126,7 @@ def _fit_full(self, pairs, y): eps = 0.01 # error-bound of iterative projection on C1 and C2 A = self.A_ - pos_pairs, neg_pairs = pairs[y], pairs[~y] + pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] # Create weight vector from similar samples pos_diff = pos_pairs[:, 0, :] - pos_pairs[:, 1, :] @@ -244,7 +242,7 @@ def _fit_diag(self, pairs, y): dissimilar pairs """ num_dim = pairs.shape[2] - pos_pairs, neg_pairs = pairs[y], pairs[~y] + pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1] s_sum = np.sum((pos_pairs[:, 0, :] - pos_pairs[:, 1, :]) ** 2, axis=0) it = 0 diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index 9378e260..a0a5be38 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -135,5 +135,4 @@ def fit(self, X, y, random_state=np.random): pos_neg = c.positive_negative_pairs(num_constraints, 
random_state=random_state) pairs, y = wrap_pairs(X, pos_neg) - y = 2 * y - 1 return SDML.fit(self, pairs, y) From b4bdec4bf30875253e6a60e8f6c22863daa2bf7d Mon Sep 17 00:00:00 2001 From: William de Vazelhes Date: Thu, 24 May 2018 15:33:51 +0200 Subject: [PATCH 6/6] update docstrings with change for +1/-1 labels (see https://github.com/metric-learn/metric-learn/pull/92#discussion_r190584573) --- metric_learn/itml.py | 2 +- metric_learn/mmc.py | 2 +- metric_learn/sdml.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/metric_learn/itml.py b/metric_learn/itml.py index 27d1835a..4d719591 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -87,7 +87,7 @@ def fit(self, pairs, y, bounds=None): pairs: array-like, shape=(n_constraints, 2, n_features) Array of pairs. Each row corresponds to two points. y: array-like, of shape (n_constraints,) - Labels of constraints. Should be 0 for dissimilar pair, 1 for similar. + Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. bounds : list (pos,neg) pairs, optional bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index 46445103..a72fa14b 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -67,7 +67,7 @@ def fit(self, pairs, y): pairs: array-like, shape=(n_constraints, 2, n_features) Array of pairs. Each row corresponds to two points. y: array-like, of shape (n_constraints,) - Labels of constraints. Should be 0 for dissimilar pair, 1 for similar. + Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. Returns ------- diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index a0a5be38..19919ab1 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -65,7 +65,7 @@ def fit(self, pairs, y): pairs: array-like, shape=(n_constraints, 2, n_features) Array of pairs. Each row corresponds to two points. y: array-like, of shape (n_constraints,) - Labels of constraints. 
Should be 0 for dissimilar pair, 1 for similar. + Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. Returns -------