diff --git a/metric_learn/itml.py b/metric_learn/itml.py
index 158ec4d3..4316802c 100644
--- a/metric_learn/itml.py
+++ b/metric_learn/itml.py
@@ -73,9 +73,9 @@ def _fit(self, pairs, y, bounds=None):
     self.bounds_[self.bounds_==0] = 1e-9
     # init metric
     if self.A0 is None:
-      self.A_ = np.identity(pairs.shape[2])
+      A = np.identity(pairs.shape[2])
     else:
-      self.A_ = check_array(self.A0)
+      A = check_array(self.A0, copy=True)
     gamma = self.gamma
     pos_pairs, neg_pairs = pairs[y == 1], pairs[y == -1]
     num_pos = len(pos_pairs)
@@ -87,7 +87,6 @@ def _fit(self, pairs, y, bounds=None):
     neg_bhat = np.zeros(num_neg) + self.bounds_[1]
     pos_vv = pos_pairs[:, 0, :] - pos_pairs[:, 1, :]
     neg_vv = neg_pairs[:, 0, :] - neg_pairs[:, 1, :]
-    A = self.A_

     for it in xrange(self.max_iter):
       # update positives
@@ -125,7 +124,7 @@ def _fit(self, pairs, y, bounds=None):
       print('itml converged at iter: %d, conv = %f' % (it, conv))
     self.n_iter_ = it

-    self.transformer_ = transformer_from_metric(self.A_)
+    self.transformer_ = transformer_from_metric(A)
     return self


@@ -134,6 +133,18 @@ class ITML(_BaseITML, _PairsClassifierMixin):

   Attributes
   ----------
+  bounds_ : array-like, shape=(2,)
+    Bounds on similarity, aside from slack variables, s.t.
+    ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
+    and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
+    dissimilar points ``c`` and ``d``, with ``d`` the learned distance. If
+    no value is provided at fit time, bounds_[0] and bounds_[1] are set at
+    train time to the 5th and 95th percentile of the pairwise distances among
+    all points present in the input `pairs`.
+
+  n_iter_ : `int`
+    The number of iterations the solver has run.
+
   transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
     metric (See function `transformer_from_metric`.)
@@ -151,8 +162,14 @@ def fit(self, pairs, y, bounds=None):
       preprocessor.
     y: array-like, of shape (n_constraints,)
       Labels of constraints. Should be -1 for dissimilar pair, 1 for similar.
-    bounds : list (pos,neg) pairs, optional
-      bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg
+    bounds : `list` of two numbers
+      Bounds on similarity, aside from slack variables, s.t.
+      ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
+      and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
+      dissimilar points ``c`` and ``d``, with ``d`` the learned distance.
+      If no value is provided, bounds_[0] and bounds_[1] will be
+      set to the 5th and 95th percentile of the pairwise distances among all
+      points present in the input `pairs`.

     Returns
     -------
@@ -167,6 +184,18 @@ class ITML_Supervised(_BaseITML, TransformerMixin):

   Attributes
   ----------
+  bounds_ : array-like, shape=(2,)
+    Bounds on similarity, aside from slack variables, s.t.
+    ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
+    and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
+    dissimilar points ``c`` and ``d``, with ``d`` the learned distance.
+    If not provided at initialization, bounds_[0] and bounds_[1] are set at
+    train time to the 5th and 95th percentile of the pairwise distances
+    among all points in the training data `X`.
+
+  n_iter_ : `int`
+    The number of iterations the solver has run.
+
   transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
     metric (See function `transformer_from_metric`.)
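Note on the `bounds_` docstrings added above: they describe a percentile-based default in place of the old free-form `(pos, neg)` list. A minimal sketch of that convention using plain numpy/scipy (an illustration of the documented behavior, not metric-learn's internal code):

    import numpy as np
    from scipy.spatial.distance import pdist

    # 5th / 95th percentile of pairwise distances, as the docstrings describe.
    X = np.random.RandomState(42).randn(20, 3)  # stand-in for the points in `pairs`
    bounds = np.percentile(pdist(X), (5, 95))
    assert bounds[0] < bounds[1]  # similarity bound sits below dissimilarity bound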
@@ -193,8 +222,14 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
       be removed in 0.6.0.
     num_constraints: int, optional
       number of constraints to generate
-    bounds : list (pos,neg) pairs, optional
-      bounds on similarity, s.t. d(X[a],X[b]) < pos and d(X[c],X[d]) > neg
+    bounds : `list` of two numbers
+      Bounds on similarity, aside from slack variables, s.t.
+      ``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
+      and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
+      dissimilar points ``c`` and ``d``, with ``d`` the learned distance.
+      If not provided at initialization, bounds_[0] and bounds_[1] will be
+      set to the 5th and 95th percentile of the pairwise distances among all
+      points in the training data `X`.
     A0 : (d x d) matrix, optional
       initial regularization matrix, defaults to identity
     verbose : bool, optional
diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py
index 1d7ddf2a..f9cd0e91 100644
--- a/metric_learn/lmnn.py
+++ b/metric_learn/lmnn.py
@@ -60,20 +60,20 @@ def fit(self, X, y):
     X, y = self._prepare_inputs(X, y, dtype=float,
                                 ensure_min_samples=2)
     num_pts, num_dims = X.shape
-    unique_labels, self.label_inds_ = np.unique(y, return_inverse=True)
-    if len(self.label_inds_) != num_pts:
+    unique_labels, label_inds = np.unique(y, return_inverse=True)
+    if len(label_inds) != num_pts:
       raise ValueError('Must have one label per point.')
     self.labels_ = np.arange(len(unique_labels))
     if self.use_pca:
       warnings.warn('use_pca does nothing for the python_LMNN implementation')
     self.transformer_ = np.eye(num_dims)
-    required_k = np.bincount(self.label_inds_).min()
+    required_k = np.bincount(label_inds).min()
     if self.k > required_k:
       raise ValueError('not enough class labels for specified k'
                        ' (smallest class has %d)' % required_k)

-    target_neighbors = self._select_targets(X)
-    impostors = self._find_impostors(target_neighbors[:, -1], X)
+    target_neighbors = self._select_targets(X, label_inds)
+    impostors = self._find_impostors(target_neighbors[:, -1], X, label_inds)
     if len(impostors) == 0:
       # L has already been initialized to an identity matrix
       return
@@ -196,23 +196,23 @@ def _loss_grad(self, X, L, dfG, impostors, it, k, reg, target_neighbors, df,
     objective += G.flatten().dot(L.T.dot(L).flatten())
     return G, objective, total_active, df, a1, a2

-  def _select_targets(self, X):
+  def _select_targets(self, X, label_inds):
     target_neighbors = np.empty((X.shape[0], self.k), dtype=int)
     for label in self.labels_:
-      inds, = np.nonzero(self.label_inds_ == label)
+      inds, = np.nonzero(label_inds == label)
       dd = euclidean_distances(X[inds], squared=True)
       np.fill_diagonal(dd, np.inf)
       nn = np.argsort(dd)[..., :self.k]
       target_neighbors[inds] = inds[nn]
     return target_neighbors

-  def _find_impostors(self, furthest_neighbors, X):
+  def _find_impostors(self, furthest_neighbors, X, label_inds):
     Lx = self.transform(X)
     margin_radii = 1 + _inplace_paired_L2(Lx[furthest_neighbors], Lx)
     impostors = []
     for label in self.labels_[:-1]:
-      in_inds, = np.nonzero(self.label_inds_ == label)
-      out_inds, = np.nonzero(self.label_inds_ > label)
+      in_inds, = np.nonzero(label_inds == label)
+      out_inds, = np.nonzero(label_inds > label)
       dist = euclidean_distances(Lx[out_inds], Lx[in_inds], squared=True)
       i1,j1 = np.nonzero(dist < margin_radii[out_inds][:,None])
       i2,j2 = np.nonzero(dist < margin_radii[in_inds])
@@ -265,6 +265,9 @@ class LMNN(_base_LMNN):

   Attributes
   ----------
+  n_iter_ : `int`
+    The number of iterations the solver has run.
+
   transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
     The learned linear transformation ``L``.
   """
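For context on the `label_inds` array that the LMNN changes above now pass around explicitly instead of storing on `self`: a short sketch of the same numpy bookkeeping, with made-up labels:

    import numpy as np

    y = np.array(['a', 'b', 'a', 'c', 'b', 'a'])
    unique_labels, label_inds = np.unique(y, return_inverse=True)
    # label_inds encodes each label as an index into unique_labels.
    assert list(label_inds) == [0, 1, 0, 2, 1, 0]
    # The smallest class size caps the usable k, as checked in fit():
    required_k = np.bincount(label_inds).min()  # == 1 here (class 'c')
    assert required_k == 1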
diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py
index 50fcfa3e..312990ab 100644
--- a/metric_learn/lsml.py
+++ b/metric_learn/lsml.py
@@ -50,32 +50,32 @@ def _fit(self, quadruplets, y=None, weights=None):
                                          type_of_inputs='tuples')

     # check to make sure that no two constrained vectors are identical
-    self.vab_ = quadruplets[:, 0, :] - quadruplets[:, 1, :]
-    self.vcd_ = quadruplets[:, 2, :] - quadruplets[:, 3, :]
-    if self.vab_.shape != self.vcd_.shape:
+    vab = quadruplets[:, 0, :] - quadruplets[:, 1, :]
+    vcd = quadruplets[:, 2, :] - quadruplets[:, 3, :]
+    if vab.shape != vcd.shape:
       raise ValueError('Constraints must have same length')
     if weights is None:
-      self.w_ = np.ones(self.vab_.shape[0])
+      self.w_ = np.ones(vab.shape[0])
     else:
       self.w_ = weights
     self.w_ /= self.w_.sum()  # weights must sum to 1
     if self.prior is None:
       X = np.vstack({tuple(row) for row in
                      quadruplets.reshape(-1, quadruplets.shape[2])})
-      self.prior_inv_ = np.atleast_2d(np.cov(X, rowvar=False))
-      self.M_ = np.linalg.inv(self.prior_inv_)
+      prior_inv = np.atleast_2d(np.cov(X, rowvar=False))
+      M = np.linalg.inv(prior_inv)
     else:
-      self.M_ = self.prior
-      self.prior_inv_ = np.linalg.inv(self.prior)
+      M = self.prior
+      prior_inv = np.linalg.inv(self.prior)

     step_sizes = np.logspace(-10, 0, 10)
     # Keep track of the best step size and the loss at that step.
     l_best = 0
-    s_best = self._total_loss(self.M_)
+    s_best = self._total_loss(M, vab, vcd, prior_inv)
     if self.verbose:
       print('initial loss', s_best)
     for it in xrange(1, self.max_iter+1):
-      grad = self._gradient(self.M_)
+      grad = self._gradient(M, vab, vcd, prior_inv)
       grad_norm = scipy.linalg.norm(grad)
       if grad_norm < self.tol:
         break
@@ -84,10 +84,10 @@ def _fit(self, quadruplets, y=None, weights=None):
       M_best = None
       for step_size in step_sizes:
         step_size /= grad_norm
-        new_metric = self.M_ - step_size * grad
+        new_metric = M - step_size * grad
         w, v = scipy.linalg.eigh(new_metric)
         new_metric = v.dot((np.maximum(w, 1e-8) * v).T)
-        cur_s = self._total_loss(new_metric)
+        cur_s = self._total_loss(new_metric, vab, vcd, prior_inv)
         if cur_s < s_best:
           l_best = step_size
           s_best = cur_s
@@ -96,36 +96,36 @@ def _fit(self, quadruplets, y=None, weights=None):
         print('iter', it, 'cost', s_best, 'best step', l_best * grad_norm)
       if M_best is None:
         break
-      self.M_ = M_best
+      M = M_best
     else:
       if self.verbose:
         print("Didn't converge after", it, "iterations. Final loss:", s_best)
     self.n_iter_ = it
-    self.transformer_ = transformer_from_metric(self.M_)
+    self.transformer_ = transformer_from_metric(M)
     return self

-  def _comparison_loss(self, metric):
-    dab = np.sum(self.vab_.dot(metric) * self.vab_, axis=1)
-    dcd = np.sum(self.vcd_.dot(metric) * self.vcd_, axis=1)
+  def _comparison_loss(self, metric, vab, vcd):
+    dab = np.sum(vab.dot(metric) * vab, axis=1)
+    dcd = np.sum(vcd.dot(metric) * vcd, axis=1)
     violations = dab > dcd
     return self.w_[violations].dot((np.sqrt(dab[violations]) -
                                     np.sqrt(dcd[violations]))**2)

-  def _total_loss(self, metric):
+  def _total_loss(self, metric, vab, vcd, prior_inv):
     # Regularization loss
     sign, logdet = np.linalg.slogdet(metric)
-    reg_loss = np.sum(metric * self.prior_inv_) - sign * logdet
-    return self._comparison_loss(metric) + reg_loss
+    reg_loss = np.sum(metric * prior_inv) - sign * logdet
+    return self._comparison_loss(metric, vab, vcd) + reg_loss

-  def _gradient(self, metric):
-    dMetric = self.prior_inv_ - np.linalg.inv(metric)
-    dabs = np.sum(self.vab_.dot(metric) * self.vab_, axis=1)
-    dcds = np.sum(self.vcd_.dot(metric) * self.vcd_, axis=1)
+  def _gradient(self, metric, vab, vcd, prior_inv):
+    dMetric = prior_inv - np.linalg.inv(metric)
+    dabs = np.sum(vab.dot(metric) * vab, axis=1)
+    dcds = np.sum(vcd.dot(metric) * vcd, axis=1)
     violations = dabs > dcds
     # TODO: vectorize
-    for vab, dab, vcd, dcd in zip(self.vab_[violations], dabs[violations],
-                                  self.vcd_[violations], dcds[violations]):
+    for vab, dab, vcd, dcd in zip(vab[violations], dabs[violations],
+                                  vcd[violations], dcds[violations]):
       dMetric += ((1-np.sqrt(dcd/dab))*np.outer(vab, vab) +
                   (1-np.sqrt(dab/dcd))*np.outer(vcd, vcd))
     return dMetric
@@ -136,6 +136,9 @@ class LSML(_BaseLSML, _QuadrupletsClassifierMixin):

   Attributes
   ----------
+  n_iter_ : `int`
+    The number of iterations the solver has run.
+
   transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
     metric (See function `transformer_from_metric`.)
@@ -169,6 +172,9 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):

   Attributes
   ----------
+  n_iter_ : `int`
+    The number of iterations the solver has run.
+
   transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
     metric (See function `transformer_from_metric`.)
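The LSML line search above keeps each candidate metric positive definite by clipping its eigenvalues at 1e-8 (`v.dot((np.maximum(w, 1e-8) * v).T)`). A standalone sketch of that projection step:

    import numpy as np
    import scipy.linalg

    M = np.array([[1., 2.],
                  [2., 1.]])  # symmetric, eigenvalues -1 and 3
    w, v = scipy.linalg.eigh(M)
    M_psd = v.dot((np.maximum(w, 1e-8) * v).T)  # reassemble with clipped spectrum
    assert np.all(np.linalg.eigvalsh(M_psd) >= 1e-8 - 1e-12)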
Final loss:", s_best) self.n_iter_ = it - self.transformer_ = transformer_from_metric(self.M_) + self.transformer_ = transformer_from_metric(M) return self - def _comparison_loss(self, metric): - dab = np.sum(self.vab_.dot(metric) * self.vab_, axis=1) - dcd = np.sum(self.vcd_.dot(metric) * self.vcd_, axis=1) + def _comparison_loss(self, metric, vab, vcd): + dab = np.sum(vab.dot(metric) * vab, axis=1) + dcd = np.sum(vcd.dot(metric) * vcd, axis=1) violations = dab > dcd return self.w_[violations].dot((np.sqrt(dab[violations]) - np.sqrt(dcd[violations]))**2) - def _total_loss(self, metric): + def _total_loss(self, metric, vab, vcd, prior_inv): # Regularization loss sign, logdet = np.linalg.slogdet(metric) - reg_loss = np.sum(metric * self.prior_inv_) - sign * logdet - return self._comparison_loss(metric) + reg_loss + reg_loss = np.sum(metric * prior_inv) - sign * logdet + return self._comparison_loss(metric, vab, vcd) + reg_loss - def _gradient(self, metric): - dMetric = self.prior_inv_ - np.linalg.inv(metric) - dabs = np.sum(self.vab_.dot(metric) * self.vab_, axis=1) - dcds = np.sum(self.vcd_.dot(metric) * self.vcd_, axis=1) + def _gradient(self, metric, vab, vcd, prior_inv): + dMetric = prior_inv - np.linalg.inv(metric) + dabs = np.sum(vab.dot(metric) * vab, axis=1) + dcds = np.sum(vcd.dot(metric) * vcd, axis=1) violations = dabs > dcds # TODO: vectorize - for vab, dab, vcd, dcd in zip(self.vab_[violations], dabs[violations], - self.vcd_[violations], dcds[violations]): + for vab, dab, vcd, dcd in zip(vab[violations], dabs[violations], + vcd[violations], dcds[violations]): dMetric += ((1-np.sqrt(dcd/dab))*np.outer(vab, vab) + (1-np.sqrt(dab/dcd))*np.outer(vcd, vcd)) return dMetric @@ -136,6 +136,9 @@ class LSML(_BaseLSML, _QuadrupletsClassifierMixin): Attributes ---------- + n_iter_ : `int` + The number of iterations the solver has run. + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) @@ -169,6 +172,9 @@ class LSML_Supervised(_BaseLSML, TransformerMixin): Attributes ---------- + n_iter_ : `int` + The number of iterations the solver has run. + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) diff --git a/metric_learn/mlkr.py b/metric_learn/mlkr.py index 6b79638e..74a21a82 100644 --- a/metric_learn/mlkr.py +++ b/metric_learn/mlkr.py @@ -30,6 +30,9 @@ class MLKR(MahalanobisMixin, TransformerMixin): Attributes ---------- + n_iter_ : `int` + The number of iterations the solver has run. + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The learned linear transformation ``L``. """ diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index b806a97e..f9d3690b 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -353,6 +353,9 @@ class MMC(_BaseMMC, _PairsClassifierMixin): Attributes ---------- + n_iter_ : `int` + The number of iterations the solver has run. + transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) @@ -384,6 +387,9 @@ class MMC_Supervised(_BaseMMC, TransformerMixin): Attributes ---------- + n_iter_ : `int` + The number of iterations the solver has run. 
diff --git a/metric_learn/nca.py b/metric_learn/nca.py
index 81045287..5abe52e3 100644
--- a/metric_learn/nca.py
+++ b/metric_learn/nca.py
@@ -24,6 +24,9 @@ class NCA(MahalanobisMixin, TransformerMixin):

   Attributes
   ----------
+  n_iter_ : `int`
+    The number of iterations the solver has run.
+
   transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
     The learned linear transformation ``L``.
   """
diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py
index be45d3a3..78fc4ebc 100644
--- a/metric_learn/sdml.py
+++ b/metric_learn/sdml.py
@@ -58,18 +58,18 @@ def _fit(self, pairs, y):
     # set up prior M
     if self.use_cov:
       X = np.vstack({tuple(row) for row in
                      pairs.reshape(-1, pairs.shape[2])})
-      self.M_ = pinvh(np.atleast_2d(np.cov(X, rowvar = False)))
+      M = pinvh(np.atleast_2d(np.cov(X, rowvar = False)))
     else:
-      self.M_ = np.identity(pairs.shape[2])
+      M = np.identity(pairs.shape[2])
     diff = pairs[:, 0] - pairs[:, 1]
     loss_matrix = (diff.T * y).dot(diff)
-    P = self.M_ + self.balance_param * loss_matrix
+    P = M + self.balance_param * loss_matrix
     emp_cov = pinvh(P)
     # hack: ensure positive semidefinite
     emp_cov = emp_cov.T.dot(emp_cov)
-    _, self.M_ = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)
+    _, M = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)

-    self.transformer_ = transformer_from_metric(self.M_)
+    self.transformer_ = transformer_from_metric(M)
     return self
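Finally, in the SDML hunk above, `loss_matrix = (diff.T * y).dot(diff)` is a vectorized signed sum of outer products over pair differences. A quick sketch verifying that reading, on made-up data:

    import numpy as np

    rng = np.random.RandomState(0)
    diff = rng.randn(5, 3)                # one row per pair difference
    y = np.array([1, -1, 1, 1, -1])       # pair labels
    loss_matrix = (diff.T * y).dot(diff)  # sum_i y_i * d_i d_i^T
    expected = sum(yi * np.outer(di, di) for yi, di in zip(y, diff))
    assert np.allclose(loss_matrix, expected)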