diff --git a/metric_learn/__init__.py b/metric_learn/__init__.py
index 92823fb1..addf2e1b 100644
--- a/metric_learn/__init__.py
+++ b/metric_learn/__init__.py
@@ -1,7 +1,7 @@
 from .constraints import Constraints
 from .covariance import Covariance
 from .itml import ITML, ITML_Supervised
-from .lmnn import LMNN
+from .lmnn import LMNN, LMNN_Supervised
 from .lsml import LSML, LSML_Supervised
 from .sdml import SDML, SDML_Supervised
 from .nca import NCA
@@ -14,7 +14,7 @@
 from ._version import __version__
 
 __all__ = ['Constraints', 'Covariance', 'ITML', 'ITML_Supervised',
-           'LMNN', 'LSML', 'LSML_Supervised', 'SDML',
+           'LMNN', 'LMNN_Supervised', 'LSML', 'LSML_Supervised', 'SDML',
            'SDML_Supervised', 'NCA', 'LFDA', 'RCA', 'RCA_Supervised',
            'MLKR', 'MMC', 'MMC_Supervised', 'SCML',
            'SCML_Supervised', '__version__']
diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py
index 47bb065f..8404790f 100644
--- a/metric_learn/lmnn.py
+++ b/metric_learn/lmnn.py
@@ -4,21 +4,233 @@
 import numpy as np
 from collections import Counter
 from sklearn.metrics import euclidean_distances
+from sklearn.metrics.pairwise import paired_euclidean_distances
 from sklearn.base import TransformerMixin
 import warnings
 
 from ._util import _initialize_components, _check_n_components
-from .base_metric import MahalanobisMixin
+from .base_metric import _TripletsClassifierMixin, MahalanobisMixin
 
 
-class LMNN(MahalanobisMixin, TransformerMixin):
+class _BaseLMNN(MahalanobisMixin, TransformerMixin, _TripletsClassifierMixin):
+  """Large Margin Nearest Neighbor (LMNN)"""
+
+  _tuple_size = 3  # constraints are triplets
+
+  def __init__(self, init='auto', n_neighbors='infer', min_iter=50,
+               max_iter=1000, learn_rate=1e-7, regularization=0.5,
+               convergence_tol=0.001, verbose=False, preprocessor=None,
+               n_components=None, random_state=None, k='deprecated'):
+    self.init = init
+    if k != 'deprecated':
+      warnings.warn('"k" parameter has been renamed to "n_neighbors".'
+                    ' It has been deprecated in version 0.6.3 and will be'
+                    ' removed in 0.7.0', FutureWarning)
+      n_neighbors = k
+    self.k = 'deprecated'  # To avoid no_attribute error
+    self.n_neighbors = n_neighbors
+    self.min_iter = min_iter
+    self.max_iter = max_iter
+    self.learn_rate = learn_rate
+    self.regularization = regularization
+    self.convergence_tol = convergence_tol
+    self.verbose = verbose
+    self.n_components = n_components
+    self.random_state = random_state
+    super(_BaseLMNN, self).__init__(preprocessor)
+
+  def _fit(self, triplets):
+    """Learn the LMNN metric from a set of triplets.
+
+    The components are optimized by gradient descent with an adaptive
+    learning rate on the LMNN pull/push objective, where target neighbors
+    and impostors are derived from the triplet constraints.
+ """ + + if not isinstance(self.max_iter, int): + raise ValueError("max_iter should be an integer, instead it is of type" + " %s" % type(self.max_iter)) + if not isinstance(self.min_iter, int): + raise ValueError("max_iter should be an integer, instead it is of type" + " %s" % type(self.min_iter)) + + if(self.min_iter > self.max_iter): + raise ValueError("The value of min_iter must be equal or smaller than" + " max_iter.") + reg = self.regularization + learn_rate = self.learn_rate + + ## Prepare inputs + # Currently prepare_inputs makes triplets contain points and not indices + triplets = self._prepare_inputs(triplets, type_of_inputs='tuples') + + if self.n_neighbors == 'infer': + self.n_neighbors = k = self._infer_n_neighbors(triplets) + else: + k = self.n_neighbors + + # TODO: + # This algorithm is built to work with indices, but in order to be + # compliant with the current handling of inputs it is converted + # back to indices by the following function. This should be improved + # in the future. + triplets, X, label_mask = self._to_index_points(triplets) + num_pts, d = X.shape + + output_dim = _check_n_components(d, self.n_components) + + self.components_ = _initialize_components(output_dim, X, None, self.init, + self.verbose, + random_state=self.random_state, + has_classes=False) + + + target_neighbors = self._select_targets(X, label_mask) + dfG = _sum_outer_products(X, target_neighbors.flatten(), + np.repeat(np.arange(X.shape[0]), k)) + + # initialize L + L = self.components_ + + # first iteration: we compute variables (including objective and gradient) + # at initialization point + G, objective, total_active = self._loss_grad(X, L, dfG, k, + reg, target_neighbors, + label_mask) + + it = 1 # we already made one iteration + + if self.verbose: + print("iter | objective | objective difference | active constraints", + "| learning rate") + + # main loop + for it in range(2, self.max_iter): + # then at each iteration, we try to find a value of L that has better + # objective than the previous L, following the gradient: + while True: + # the next point next_L to try out is found by a gradient step + L_next = L - learn_rate * G + # we compute the objective at next point + # we copy variables that can be modified by _loss_grad, because if we + # retry we don t want to modify them several times + (G_next, objective_next, total_active_next) = ( + self._loss_grad(X, L_next, dfG, k, reg, target_neighbors, + label_mask)) + assert not np.isnan(objective) + delta_obj = objective_next - objective + if delta_obj > 0: + # if we did not find a better objective, we retry with an L closer to + # the starting point, by decreasing the learning rate (making the + # gradient step smaller) + learn_rate /= 2 + else: + # otherwise, if we indeed found a better obj, we get out of the loop + break + # when the better L is found (and the related variables), we set the + # old variables to these new ones before next iteration and we + # slightly increase the learning rate + L = L_next + G, objective, total_active = G_next, objective_next, total_active_next + learn_rate *= 1.01 + + if self.verbose: + print(it, objective, delta_obj, total_active, learn_rate) + + # check for convergence + if it > self.min_iter and abs(delta_obj) < self.convergence_tol: + if self.verbose: + print("LMNN converged with objective", objective) + break + else: + if self.verbose: + print("LMNN didn't converge in %d steps." 
+
+    # store the last L
+    self.components_ = L
+    self.n_iter_ = it
+
+  def _loss_grad(self, X, L, dfG, k, reg, target_neighbors, label_mask):
+    # project the points with the current transformation L
+    Lx = L.dot(X.T).T
+    # we need to find the furthest target neighbor of each point:
+    Ni = 1 + _inplace_paired_L2(Lx[target_neighbors], Lx[:, None, :])
+    furthest_neighbors = np.take_along_axis(target_neighbors,
+                                            Ni.argmax(axis=1)[:, None], 1)
+    impostors = self._find_impostors(furthest_neighbors.ravel(), X,
+                                     label_mask, L)
+    g0 = _inplace_paired_L2(*Lx[impostors])
+
+    # distances to the target neighbors of both points of each impostor pair
+    g1, g2 = Ni[impostors]
+    # compute the gradient
+    total_active = 0
+    df = np.zeros((X.shape[1], X.shape[1]))
+    for nn_idx in reversed(range(k)):  # note: reverse not useful here
+      act1 = g0 < g1[:, nn_idx]
+      act2 = g0 < g2[:, nn_idx]
+      total_active += act1.sum() + act2.sum()
+
+      targets = target_neighbors[:, nn_idx]
+      PLUS, pweight = _count_edges(act1, act2, impostors, targets)
+      df += _sum_outer_products(X, PLUS[:, 0], PLUS[:, 1], pweight)
+
+      in_imp, out_imp = impostors
+      df -= _sum_outer_products(X, in_imp[act1], out_imp[act1])
+      df -= _sum_outer_products(X, in_imp[act2], out_imp[act2])
+
+    # do the gradient update
+    assert not np.isnan(df).any()
+    G = dfG * reg + df * (1 - reg)
+    G = L.dot(G)
+    # compute the objective function
+    objective = total_active * (1 - reg)
+    objective += G.flatten().dot(L.flatten())
+    return 2 * G, objective, total_active
+
+  def _to_index_points(self, triplets):
+    shape = triplets.shape
+    X, triplets = np.unique(np.vstack(triplets), return_inverse=True, axis=0)
+    triplets = triplets.reshape(shape[:2])
+    # pairwise label mask implied by the triplets: +1 for (anchor, genuine)
+    # pairs and -1 for (anchor, impostor) pairs
+    label_mask = np.zeros((X.shape[0], X.shape[0]))
+    for (i, j, k) in triplets:
+      label_mask[i, j] = 1
+      label_mask[i, k] = -1
+    return triplets, X, label_mask
+
+  def _select_targets(self, X, label_mask):
+    # the target neighbors of a point are its n_neighbors closest points
+    # among those not marked as impostors for it
+    dd = euclidean_distances(X, squared=True)
+    np.fill_diagonal(dd, np.inf)
+    dd[label_mask == -1] = np.inf
+    return np.argsort(dd)[..., :self.n_neighbors]
+
+  def _find_impostors(self, furthest_neighbors, X, label_mask, L):
+    Lx = X.dot(L.T)
+    margin_radii = 1 + _inplace_paired_L2(Lx[furthest_neighbors], Lx)
+    # find impostors from the pairwise label mask: a pair (i, j) marked -1
+    # is active if j lies within the margin radius of its anchor i
+    i, j = (label_mask == -1).nonzero()
+    dist = np.square(paired_euclidean_distances(Lx[i], Lx[j]))
+    within_margin = dist < margin_radii[i]
+    return np.array([i[within_margin], j[within_margin]])
+
+  def _infer_n_neighbors(self, triplets):
+    # infer the number of neighbors from the triplets: with triplets ordered
+    # (anchor, genuine, impostor), the minimum count of repeated
+    # (anchor, impostor) pairs gives k_genuine and the minimum count of
+    # repeated (anchor, genuine) pairs gives k_imposter
+    k_genuine = np.unique(triplets[:, [0, 2]], axis=0,
+                          return_counts=True)[1].min()
+    k_imposter = np.unique(triplets[:, [0, 1]], axis=0,
+                           return_counts=True)[1].min()
+    return k_genuine + k_imposter
+
+
+class LMNN(_BaseLMNN):
   """Large Margin Nearest Neighbor (LMNN)
 
-  LMNN learns a Mahalanobis distance metric in the kNN classification
-  setting. The learned metric attempts to keep close k-nearest neighbors
-  from the same class, while keeping examples from different classes
-  separated by a large margin. This algorithm makes no assumptions about
-  the distribution of the data.
+  LMNN learns a Mahalanobis distance metric from triplets.
+  The learned metric attempts to keep the target neighbors of each point
+  close, while keeping potential impostors separated by a large margin.
+  This algorithm makes no assumptions about the distribution of the data.
 
   Read more in the :ref:`User Guide <lmnn>`.
@@ -42,13 +254,6 @@ class LMNN(MahalanobisMixin, TransformerMixin):
       to :meth:`fit` will be used to initialize the transformation.
       (See `sklearn.decomposition.PCA`)
 
-    'lda'
-      ``min(n_components, n_classes)`` most discriminative
-      components of the inputs passed to :meth:`fit` will be used to
-      initialize the transformation. (If ``n_components > n_classes``,
-      the rest of the components will be zero.) (See
-      `sklearn.discriminant_analysis.LinearDiscriminantAnalysis`)
-
     'identity'
       If ``n_components`` is strictly smaller than the
       dimensionality of the inputs passed to :meth:`fit`, the identity
@@ -64,8 +269,10 @@ class LMNN(MahalanobisMixin, TransformerMixin):
     :meth:`fit` and n_features_a must be less than or equal to that.
     If ``n_components`` is not None, n_features_a must match it.
 
-  n_neighbors : int, optional (default=3)
-    Number of neighbors to consider, not including self-edges.
+  n_neighbors : int, optional (default='infer')
+    Number of neighbors to consider, not including self-edges. If 'infer',
+    the number of nearest neighbors is inferred by counting point pairs
+    among the provided triplets.
 
   min_iter : int, optional (default=50)
     Minimum number of iterations of the optimization procedure.
@@ -131,6 +338,121 @@ class LMNN(MahalanobisMixin, TransformerMixin):
          2005.
   """
 
+  def fit(self, triplets):
+    """Learn the LMNN model.
+
+    Parameters
+    ----------
+    triplets : array-like, shape=(n_constraints, 3, n_features) or \
+        (n_constraints, 3)
+      3D array-like of triplets of points or 2D array of triplets of
+      indicators. Triplets are assumed to be ordered such that:
+      d(triplets[i, 0], triplets[i, 1]) < d(triplets[i, 0], triplets[i, 2]).
+
+    Returns
+    -------
+    self : object
+      Returns the instance.
+    """
+    return self._fit(triplets)
+
+
+class LMNN_Supervised(MahalanobisMixin, TransformerMixin):
+  """Large Margin Nearest Neighbor (LMNN)
+
+  LMNN learns a Mahalanobis distance metric in the kNN classification
+  setting. The learned metric attempts to keep close k-nearest neighbors
+  from the same class, while keeping examples from different classes
+  separated by a large margin. This algorithm makes no assumptions about
+  the distribution of the data.
+
+  Read more in the :ref:`User Guide <lmnn>`.
+
+  Parameters
+  ----------
+  init : string or numpy array, optional (default='auto')
+    Initialization of the linear transformation. Possible options are
+    'auto', 'pca', 'identity', 'random', and a numpy array of shape
+    (n_features_a, n_features_b).
+
+    'auto'
+      Depending on ``n_components``, the most reasonable initialization
+      will be chosen. If ``n_components <= n_classes`` we use 'lda', as
+      it uses labels information. If not, but
+      ``n_components < min(n_features, n_samples)``, we use 'pca', as
+      it projects data in meaningful directions (those of higher
+      variance). Otherwise, we just use 'identity'.
+
+    'pca'
+      ``n_components`` principal components of the inputs passed
+      to :meth:`fit` will be used to initialize the transformation.
+      (See `sklearn.decomposition.PCA`)
+
+    'lda'
+      ``min(n_components, n_classes)`` most discriminative
+      components of the inputs passed to :meth:`fit` will be used to
+      initialize the transformation. (If ``n_components > n_classes``,
+      the rest of the components will be zero.) (See
+      `sklearn.discriminant_analysis.LinearDiscriminantAnalysis`)
+
+    'identity'
+      If ``n_components`` is strictly smaller than the
+      dimensionality of the inputs passed to :meth:`fit`, the identity
+      matrix will be truncated to the first ``n_components`` rows.
+
+    'random'
+      The initial transformation will be a random array of shape
+      `(n_components, n_features)`. Each value is sampled from the
+      standard normal distribution.
+
+    numpy array
+      n_features_b must match the dimensionality of the inputs passed to
+      :meth:`fit` and n_features_a must be less than or equal to that.
+      If ``n_components`` is not None, n_features_a must match it.
+
+  n_neighbors : int, optional (default=3)
+    Number of neighbors to consider, not including self-edges.
+
+  min_iter : int, optional (default=50)
+    Minimum number of iterations of the optimization procedure.
+
+  max_iter : int, optional (default=1000)
+    Maximum number of iterations of the optimization procedure.
+
+  learn_rate : float, optional (default=1e-7)
+    Learning rate of the optimization procedure.
+
+  convergence_tol : float, optional (default=0.001)
+    Tolerance of the optimization procedure. If the objective value varies
+    less than `convergence_tol`, we consider the algorithm has converged and
+    stop it.
+
+  verbose : bool, optional (default=False)
+    Whether to print the progress of the optimization procedure.
+
+  regularization: float, optional (default=0.5)
+    Relative weight between pull and push terms, with 0.5 meaning equal
+    weight.
+
+  preprocessor : array-like, shape=(n_samples, n_features) or callable
+    The preprocessor to call to get tuples from indices. If array-like,
+    tuples will be formed like this: X[indices].
+
+  n_components : int or None, optional (default=None)
+    Dimensionality of reduced space (if None, defaults to dimension of X).
+
+  random_state : int or numpy.RandomState or None, optional (default=None)
+    A pseudo random number generator object or a seed for it if int. If
+    ``init='random'``, ``random_state`` is used to initialize the random
+    transformation. If ``init='pca'``, ``random_state`` is passed as an
+    argument to PCA when initializing the transformation.
+
+  k : Renamed to n_neighbors. Deprecated since version 0.6.3, will be
+    removed in 0.7.0.
+
+  Attributes
+  ----------
+  n_iter_ : `int`
+    The number of iterations the solver has run.
+
+  components_ : `numpy.ndarray`, shape=(n_components, n_features)
+    The learned linear transformation ``L``.
+
+  Examples
+  --------
+  >>> import numpy as np
+  >>> from metric_learn import LMNN_Supervised
+  >>> from sklearn.datasets import load_iris
+  >>> iris_data = load_iris()
+  >>> X = iris_data['data']
+  >>> Y = iris_data['target']
+  >>> lmnn = LMNN_Supervised(n_neighbors=5, learn_rate=1e-6)
+  >>> lmnn.fit(X, Y)
+
+  References
+  ----------
+  .. [1] K. Q. Weinberger, J. Blitzer, L. K. Saul. `Distance Metric
+         Learning for Large Margin Nearest Neighbor Classification
+         <http://papers.nips.cc/paper/2795-distance-metric-learning-for-large
+         -margin-nearest-neighbor-classification>`_. NIPS
+         2005.
+ """ + def __init__(self, init='auto', n_neighbors=3, min_iter=50, max_iter=1000, learn_rate=1e-7, regularization=0.5, convergence_tol=0.001, verbose=False, preprocessor=None, @@ -152,7 +474,7 @@ def __init__(self, init='auto', n_neighbors=3, min_iter=50, max_iter=1000, self.verbose = verbose self.n_components = n_components self.random_state = random_state - super(LMNN, self).__init__(preprocessor) + super(LMNN_Supervised, self).__init__(preprocessor) def fit(self, X, y): k = self.n_neighbors diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index a39c7b3c..3f1abe95 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -19,7 +19,7 @@ HAS_SKGGM = False else: HAS_SKGGM = True -from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC, +from metric_learn import (LMNN_Supervised, NCA, LFDA, Covariance, MLKR, MMC, SCML_Supervised, LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised, SDML, RCA, ITML, SCML) @@ -381,7 +381,7 @@ def test_bounds_parameters_invalid(bounds): class TestLMNN(MetricTestCase): def test_iris(self): - lmnn = LMNN(n_neighbors=5, learn_rate=1e-6, verbose=False) + lmnn = LMNN_Supervised(n_neighbors=5, learn_rate=1e-6, verbose=False) lmnn.fit(self.iris_points, self.iris_labels) csep = class_separation(lmnn.transform(self.iris_points), @@ -396,7 +396,7 @@ def test_loss_grad_lbfgs(self): rng = np.random.RandomState(42) X, y = make_classification(random_state=rng) L = rng.randn(rng.randint(1, X.shape[1] + 1), X.shape[1]) - lmnn = LMNN() + lmnn = LMNN_Supervised() k = lmnn.n_neighbors reg = lmnn.regularization @@ -499,7 +499,7 @@ def grad(x0): scipy.optimize.check_grad(loss, grad, x0.ravel()) - class LMNN_with_callback(LMNN): + class LMNN_with_callback(LMNN_Supervised): """ We will use a callback to get the gradient (see later) """ @@ -574,7 +574,7 @@ def _loss_grad(self, X, L, dfG, k, reg, target_neighbors, label_inds): def test_toy_ex_lmnn(X, y, loss): """Test that the loss give the right result on a toy example""" L = np.array([[1]]) - lmnn = LMNN(n_neighbors=1, regularization=0.5) + lmnn = LMNN_Supervised(n_neighbors=1, regularization=0.5) k = lmnn.n_neighbors reg = lmnn.regularization @@ -608,7 +608,7 @@ def test_convergence_simple_example(capsys): # LMNN should converge on this simple example, which it did not with # this issue: https://github.com/scikit-learn-contrib/metric-learn/issues/88 X, y = make_classification(random_state=0) - lmnn = LMNN(verbose=True) + lmnn = LMNN_Supervised(verbose=True) lmnn.fit(X, y) out, _ = capsys.readouterr() assert "LMNN converged with objective" in out @@ -618,7 +618,7 @@ def test_no_twice_same_objective(capsys): # test that the objective function never has twice the same value # see https://github.com/scikit-learn-contrib/metric-learn/issues/88 X, y = make_classification(random_state=0) - lmnn = LMNN(verbose=True) + lmnn = LMNN_Supervised(verbose=True) lmnn.fit(X, y) out, _ = capsys.readouterr() lines = re.split("\n+", out) diff --git a/test/test_base_metric.py b/test/test_base_metric.py index fa641526..11d23b54 100644 --- a/test/test_base_metric.py +++ b/test/test_base_metric.py @@ -45,9 +45,9 @@ def test_lmnn(self): nndef_kwargs = {'convergence_tol': 0.01, 'n_neighbors': 6} merged_kwargs = sk_repr_kwargs(def_kwargs, nndef_kwargs) self.assertEqual( - remove_spaces(str(metric_learn.LMNN(convergence_tol=0.01, + remove_spaces(str(metric_learn.LMNN_Supervised(convergence_tol=0.01, n_neighbors=6))), - remove_spaces(f"LMNN({merged_kwargs})")) + 
remove_spaces(f"LMNN_Supervised({merged_kwargs})")) def test_nca(self): def_kwargs = {'init': 'auto', 'max_iter': 100, 'n_components': None, diff --git a/test/test_components_metric_conversion.py b/test/test_components_metric_conversion.py index c6113957..e704b56c 100644 --- a/test/test_components_metric_conversion.py +++ b/test/test_components_metric_conversion.py @@ -7,7 +7,7 @@ from metric_learn.sklearn_shims import ignore_warnings from metric_learn import ( - LMNN, NCA, LFDA, Covariance, MLKR, + LMNN_Supervised, NCA, LFDA, Covariance, MLKR, LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised) from metric_learn._util import components_from_metric from metric_learn.exceptions import NonPSDError @@ -42,7 +42,7 @@ def test_itml_supervised(self): assert_array_almost_equal(L.T.dot(L), itml.get_mahalanobis_matrix()) def test_lmnn(self): - lmnn = LMNN(n_neighbors=5, learn_rate=1e-6, verbose=False) + lmnn = LMNN_Supervised(n_neighbors=5, learn_rate=1e-6, verbose=False) lmnn.fit(self.X, self.y) L = lmnn.components_ assert_array_almost_equal(L.T.dot(L), lmnn.get_mahalanobis_matrix()) diff --git a/test/test_fit_transform.py b/test/test_fit_transform.py index 246223b0..88571fe0 100644 --- a/test/test_fit_transform.py +++ b/test/test_fit_transform.py @@ -4,7 +4,7 @@ from numpy.testing import assert_array_almost_equal from metric_learn import ( - LMNN, NCA, LFDA, Covariance, MLKR, + LMNN_Supervised, NCA, LFDA, Covariance, MLKR, LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised) @@ -52,11 +52,11 @@ def test_itml_supervised(self): assert_array_almost_equal(res_1, res_2) def test_lmnn(self): - lmnn = LMNN(n_neighbors=5, learn_rate=1e-6, verbose=False) + lmnn = LMNN_Supervised(n_neighbors=5, learn_rate=1e-6, verbose=False) lmnn.fit(self.X, self.y) res_1 = lmnn.transform(self.X) - lmnn = LMNN(n_neighbors=5, learn_rate=1e-6, verbose=False) + lmnn = LMNN_Supervised(n_neighbors=5, learn_rate=1e-6, verbose=False) res_2 = lmnn.fit_transform(self.X, self.y) assert_array_almost_equal(res_1, res_2) diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py index b5dbc248..56699baf 100644 --- a/test/test_mahalanobis_mixin.py +++ b/test/test_mahalanobis_mixin.py @@ -437,7 +437,7 @@ def test_auto_init_transformation(n_samples, n_features, n_classes, n_components=n_components, random_state=rng) # To make the test work for LMNN: - if 'LMNN' in model_base.__class__.__name__: + if 'LMNN_Supervised' in model_base.__class__.__name__: model_base.set_params(n_neighbors=1) # To make the test faster for estimators that have a max_iter: if hasattr(model_base, 'max_iter'): diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py index 798d9036..b3fce0c5 100644 --- a/test/test_sklearn_compat.py +++ b/test/test_sklearn_compat.py @@ -7,7 +7,7 @@ from metric_learn.sklearn_shims import (assert_allclose_dense_sparse, set_random_state, _get_args, is_public_parameter, get_scorer) -from metric_learn import (Covariance, LFDA, LMNN, MLKR, NCA, +from metric_learn import (Covariance, LFDA, LMNN_Supervised, MLKR, NCA, ITML_Supervised, LSML_Supervised, MMC_Supervised, RCA_Supervised, SDML_Supervised, SCML_Supervised) @@ -52,7 +52,7 @@ def test_covariance(self): check_estimator(Covariance()) def test_lmnn(self): - check_estimator(LMNN()) + check_estimator(LMNN_Supervised()) def test_lfda(self): check_estimator(LFDA()) diff --git a/test/test_utils.py b/test/test_utils.py index 43d67111..c75ed613 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -14,7 +14,7 
                                 check_y_valid_values_for_pairs,
                                 _auto_select_init, _pseudo_inverse_from_eig)
 from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA,
-                          LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised,
+                          LMNN_Supervised, MLKR, NCA, ITML_Supervised, LSML_Supervised,
                           MMC_Supervised, RCA_Supervised, SDML_Supervised,
                           SCML, SCML_Supervised, Constraints)
 from metric_learn.base_metric import (ArrayIndexer, MahalanobisMixin,
@@ -131,7 +131,7 @@ def build_quadruplets(with_preprocessor=False):
 
 classifiers = [(Covariance(), build_classification),
                (LFDA(), build_classification),
-               (LMNN(), build_classification),
+               (LMNN_Supervised(), build_classification),
                (NCA(), build_classification),
                (RCA(), build_classification),
                (ITML_Supervised(max_iter=5), build_classification),
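
For context (not part of the patch): a minimal usage sketch of the two estimators this diff introduces, assuming the patch is applied. The triplet values below are made up for illustration; the triplets must be ordered so that d(t[0], t[1]) < d(t[0], t[2]) is the relation the metric should satisfy.

    import numpy as np
    from sklearn.datasets import load_iris
    from metric_learn import LMNN, LMNN_Supervised

    # Weakly-supervised use: fit directly on (anchor, genuine, impostor) triplets.
    triplets = np.array([[[1.2, 3.2], [2.3, 5.5], [2.1, 0.6]],
                         [[4.5, 2.3], [2.1, 2.3], [7.3, 3.4]],
                         [[1.2, 3.2], [4.5, 2.3], [7.3, 3.4]],
                         [[2.3, 5.5], [1.2, 3.2], [2.1, 0.6]]])
    lmnn = LMNN(n_neighbors=1)   # n_neighbors='infer' (the default) counts pairs instead
    lmnn.fit(triplets)

    # Supervised use: the former (X, y) API now lives in LMNN_Supervised.
    X, y = load_iris(return_X_y=True)
    lmnn_sup = LMNN_Supervised(n_neighbors=5, learn_rate=1e-6)
    lmnn_sup.fit(X, y)
    lmnn_sup.get_mahalanobis_matrix()   # learned (n_features, n_features) metric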
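And a small worked example of the pair-counting heuristic behind n_neighbors='infer' (the index triplets below are hypothetical):

    import numpy as np

    # Anchor 0 with genuine neighbours {1, 2} and impostors {3, 4} yields four
    # (anchor, genuine, impostor) triplets: each (anchor, impostor) pair repeats
    # once per genuine neighbour, and each (anchor, genuine) pair once per impostor.
    triplets_idx = np.array([[0, 1, 3], [0, 1, 4],
                             [0, 2, 3], [0, 2, 4]])
    k_genuine = np.unique(triplets_idx[:, [0, 2]], axis=0,
                          return_counts=True)[1].min()    # -> 2
    k_imposter = np.unique(triplets_idx[:, [0, 1]], axis=0,
                           return_counts=True)[1].min()   # -> 2
    print(k_genuine + k_imposter)                         # -> 4 neighbours inferred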