diff --git a/metric_learn/_util.py b/metric_learn/_util.py
index 105d89b5..583f1105 100644
--- a/metric_learn/_util.py
+++ b/metric_learn/_util.py
@@ -411,3 +411,13 @@ def validate_vector(u, dtype=None):
   if u.ndim > 1:
     raise ValueError("Input vector should be 1-D.")
   return u
+
+
+def _check_n_components(n_features, n_components):
+  """Checks that n_components is not greater than n_features and deals with
+  the None case"""
+  if n_components is None:
+    return n_features
+  if 0 < n_components <= n_features:
+    return n_components
+  raise ValueError('Invalid n_components, must be in [1, %d]' % n_features)
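For reference, a quick sketch of the new helper's contract (hypothetical REPL session; the value 5 simply stands in for n_features):

    >>> from metric_learn._util import _check_n_components
    >>> _check_n_components(5, None)  # None falls back to n_features
    5
    >>> _check_n_components(5, 3)     # a valid value is passed through
    3
    >>> _check_n_components(5, 10)    # outside [1, n_features] raises
    Traceback (most recent call last):
      ...
    ValueError: Invalid n_components, must be in [1, 5]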
Use the "bounds" parameter of this ' + ' removed in 0.6.0. Use the "bounds" parameter of this ' 'fit method instead.', DeprecationWarning) X, y = self._prepare_inputs(X, y, ensure_min_samples=2) num_constraints = self.num_constraints diff --git a/metric_learn/lfda.py b/metric_learn/lfda.py index 2ca085d4..1851a734 100644 --- a/metric_learn/lfda.py +++ b/metric_learn/lfda.py @@ -16,6 +16,8 @@ from six.moves import xrange from sklearn.metrics import pairwise_distances from sklearn.base import TransformerMixin + +from ._util import _check_n_components from .base_metric import MahalanobisMixin @@ -26,23 +28,29 @@ class LFDA(MahalanobisMixin, TransformerMixin): Attributes ---------- - transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + transformer_ : `numpy.ndarray`, shape=(n_components, n_features) The learned linear transformation ``L``. ''' - def __init__(self, num_dims=None, k=None, embedding_type='weighted', - preprocessor=None): + def __init__(self, n_components=None, num_dims='deprecated', + k=None, embedding_type='weighted', preprocessor=None): ''' Initialize LFDA. Parameters ---------- - num_dims : int, optional - Dimensionality of reduced space (defaults to dimension of X) + n_components : int or None, optional (default=None) + Dimensionality of reduced space (if None, defaults to dimension of X). + + num_dims : Not used + + .. deprecated:: 0.5.0 + `num_dims` was deprecated in version 0.5.0 and will + be removed in 0.6.0. Use `n_components` instead. k : int, optional Number of nearest neighbors used in local scaling method. - Defaults to min(7, num_dims - 1). + Defaults to min(7, n_components - 1). embedding_type : str, optional Type of metric in the embedding space (default: 'weighted') @@ -56,6 +64,7 @@ def __init__(self, num_dims=None, k=None, embedding_type='weighted', ''' if embedding_type not in ('weighted', 'orthonormalized', 'plain'): raise ValueError('Invalid embedding_type: %r' % embedding_type) + self.n_components = n_components self.num_dims = num_dims self.embedding_type = embedding_type self.k = k @@ -72,17 +81,17 @@ def fit(self, X, y): y : (n,) array-like Class labels, one per point of data. ''' + if self.num_dims != 'deprecated': + warnings.warn('"num_dims" parameter is not used.' + ' It has been deprecated in version 0.5.0 and will be' + ' removed in 0.6.0. Use "n_components" instead', + DeprecationWarning) X, y = self._prepare_inputs(X, y, ensure_min_samples=2) unique_classes, y = np.unique(y, return_inverse=True) n, d = X.shape num_classes = len(unique_classes) - if self.num_dims is None: - dim = d - else: - if not 0 < self.num_dims <= d: - raise ValueError('Invalid num_dims, must be in [1,%d]' % d) - dim = self.num_dims + dim = _check_n_components(d, self.n_components) if self.k is None: k = min(7, d - 1) diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py index d70ca3d0..1ba87684 100644 --- a/metric_learn/lmnn.py +++ b/metric_learn/lmnn.py @@ -19,6 +19,8 @@ from six.moves import xrange from sklearn.metrics import euclidean_distances from sklearn.base import TransformerMixin + +from ._util import _check_n_components from .base_metric import MahalanobisMixin @@ -26,7 +28,8 @@ class _base_LMNN(MahalanobisMixin, TransformerMixin): def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7, regularization=0.5, convergence_tol=0.001, use_pca=True, - verbose=False, preprocessor=None): + verbose=False, preprocessor=None, n_components=None, + num_dims='deprecated'): """Initialize the LMNN object. 
diff --git a/metric_learn/itml.py b/metric_learn/itml.py
index e3ff515a..25518bf6 100644
--- a/metric_learn/itml.py
+++ b/metric_learn/itml.py
@@ -150,7 +150,7 @@ class ITML(_BaseITML, _PairsClassifierMixin):
   n_iter_ : `int`
     The number of iterations the solver has run.

-  transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
+  transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
     metric (See function `transformer_from_metric`.)

@@ -218,7 +218,7 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
   n_iter_ : `int`
     The number of iterations the solver has run.

-  transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
+  transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
     metric (See function `transformer_from_metric`.)
   """
@@ -292,11 +292,11 @@ def fit(self, X, y, random_state=np.random, bounds=None):
     if self.num_labeled != 'deprecated':
       warnings.warn('"num_labeled" parameter is not used.'
                     ' It has been deprecated in version 0.5.0 and will be'
-                    'removed in 0.6.0', DeprecationWarning)
+                    ' removed in 0.6.0', DeprecationWarning)
     if self.bounds != 'deprecated':
       warnings.warn('"bounds" parameter from initialization is not used.'
                     ' It has been deprecated in version 0.5.0 and will be'
-                    'removed in 0.6.0. Use the "bounds" parameter of this '
+                    ' removed in 0.6.0. Use the "bounds" parameter of this '
                     'fit method instead.', DeprecationWarning)
     X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
     num_constraints = self.num_constraints
diff --git a/metric_learn/lfda.py b/metric_learn/lfda.py
index 2ca085d4..1851a734 100644
--- a/metric_learn/lfda.py
+++ b/metric_learn/lfda.py
@@ -16,6 +16,8 @@
 from six.moves import xrange
 from sklearn.metrics import pairwise_distances
 from sklearn.base import TransformerMixin
+
+from ._util import _check_n_components
 from .base_metric import MahalanobisMixin


@@ -26,23 +28,29 @@ class LFDA(MahalanobisMixin, TransformerMixin):

   Attributes
   ----------
-  transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
+  transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
     The learned linear transformation ``L``.
   '''

-  def __init__(self, num_dims=None, k=None, embedding_type='weighted',
-               preprocessor=None):
+  def __init__(self, n_components=None, num_dims='deprecated',
+               k=None, embedding_type='weighted', preprocessor=None):
     '''
     Initialize LFDA.

     Parameters
     ----------
-    num_dims : int, optional
-      Dimensionality of reduced space (defaults to dimension of X)
+    n_components : int or None, optional (default=None)
+      Dimensionality of reduced space (if None, defaults to dimension of X).
+
+    num_dims : Not used
+
+      .. deprecated:: 0.5.0
+        `num_dims` was deprecated in version 0.5.0 and will
+        be removed in 0.6.0. Use `n_components` instead.

     k : int, optional
       Number of nearest neighbors used in local scaling method.
-      Defaults to min(7, num_dims - 1).
+      Defaults to min(7, n_features - 1).

     embedding_type : str, optional
       Type of metric in the embedding space (default: 'weighted')
@@ -56,6 +64,7 @@ def __init__(self, num_dims=None, k=None, embedding_type='weighted',
     '''
     if embedding_type not in ('weighted', 'orthonormalized', 'plain'):
       raise ValueError('Invalid embedding_type: %r' % embedding_type)
+    self.n_components = n_components
     self.num_dims = num_dims
     self.embedding_type = embedding_type
     self.k = k
@@ -72,17 +81,17 @@ def fit(self, X, y):
     y : (n,) array-like
       Class labels, one per point of data.
     '''
+    if self.num_dims != 'deprecated':
+      warnings.warn('"num_dims" parameter is not used.'
+                    ' It has been deprecated in version 0.5.0 and will be'
+                    ' removed in 0.6.0. Use "n_components" instead',
+                    DeprecationWarning)
     X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
     unique_classes, y = np.unique(y, return_inverse=True)
     n, d = X.shape
     num_classes = len(unique_classes)

-    if self.num_dims is None:
-      dim = d
-    else:
-      if not 0 < self.num_dims <= d:
-        raise ValueError('Invalid num_dims, must be in [1,%d]' % d)
-      dim = self.num_dims
+    dim = _check_n_components(d, self.n_components)

     if self.k is None:
       k = min(7, d - 1)
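The deprecation path added to `fit` can be exercised directly; a small sketch (mirroring the tests added further below, data values arbitrary):

    import warnings
    import numpy as np
    from metric_learn import LFDA

    X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
    y = np.array([1, 0, 1, 0])
    with warnings.catch_warnings(record=True) as w:
      warnings.simplefilter('always')
      LFDA(num_dims=2).fit(X, y)  # old keyword warns and is ignored
    assert any(issubclass(wi.category, DeprecationWarning) for wi in w)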
diff --git a/metric_learn/lmnn.py b/metric_learn/lmnn.py
index d70ca3d0..1ba87684 100644
--- a/metric_learn/lmnn.py
+++ b/metric_learn/lmnn.py
@@ -19,6 +19,8 @@
 from six.moves import xrange
 from sklearn.metrics import euclidean_distances
 from sklearn.base import TransformerMixin
+
+from ._util import _check_n_components
 from .base_metric import MahalanobisMixin


@@ -26,7 +28,8 @@ class _base_LMNN(MahalanobisMixin, TransformerMixin):
   def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
                regularization=0.5, convergence_tol=0.001, use_pca=True,
-               verbose=False, preprocessor=None):
+               verbose=False, preprocessor=None, n_components=None,
+               num_dims='deprecated'):
     """Initialize the LMNN object.

     Parameters
@@ -40,6 +43,15 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
     preprocessor : array-like, shape=(n_samples, n_features) or callable
       The preprocessor to call to get tuples from indices. If array-like,
       tuples will be formed like this: X[indices].
+
+    n_components : int or None, optional (default=None)
+      Dimensionality of reduced space (if None, defaults to dimension of X).
+
+    num_dims : Not used
+
+      .. deprecated:: 0.5.0
+        `num_dims` was deprecated in version 0.5.0 and will
+        be removed in 0.6.0. Use `n_components` instead.
     """
     self.k = k
     self.min_iter = min_iter
@@ -49,6 +61,8 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
     self.convergence_tol = convergence_tol
     self.use_pca = use_pca
     self.verbose = verbose
+    self.n_components = n_components
+    self.num_dims = num_dims
     super(_base_LMNN, self).__init__(preprocessor)


@@ -56,20 +70,26 @@ def __init__(self, k=3, min_iter=50, max_iter=1000, learn_rate=1e-7,
 class python_LMNN(_base_LMNN):

   def fit(self, X, y):
+    if self.num_dims != 'deprecated':
+      warnings.warn('"num_dims" parameter is not used.'
+                    ' It has been deprecated in version 0.5.0 and will be'
+                    ' removed in 0.6.0. Use "n_components" instead',
+                    DeprecationWarning)
     k = self.k
     reg = self.regularization
     learn_rate = self.learn_rate

     X, y = self._prepare_inputs(X, y, dtype=float,
                                 ensure_min_samples=2)
-    num_pts, num_dims = X.shape
+    num_pts, d = X.shape
+    output_dim = _check_n_components(d, self.n_components)
     unique_labels, label_inds = np.unique(y, return_inverse=True)
     if len(label_inds) != num_pts:
       raise ValueError('Must have one label per point.')
     self.labels_ = np.arange(len(unique_labels))
     if self.use_pca:
       warnings.warn('use_pca does nothing for the python_LMNN implementation')
-    self.transformer_ = np.eye(num_dims)
+    self.transformer_ = np.eye(output_dim, d)

     required_k = np.bincount(label_inds).min()
     if self.k > required_k:
       raise ValueError('not enough class labels for specified k'
@@ -272,7 +292,7 @@ class LMNN(_base_LMNN):
   n_iter_ : `int`
     The number of iterations the solver has run.

-  transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
+  transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
     The learned linear transformation ``L``.
   """
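Worth noting: the two-argument `np.eye(output_dim, d)` builds a rectangular identity, so python_LMNN's starting point already has the reduced shape, unlike the former square `np.eye(num_dims)`. For instance:

    import numpy as np

    L = np.eye(2, 4)  # ones on the main diagonal, shape (2, 4)
    assert L.shape == (2, 4)
    assert np.allclose(L, [[1., 0., 0., 0.],
                           [0., 1., 0., 0.]])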
""" @@ -469,7 +469,7 @@ def fit(self, X, y, random_state=np.random): if self.num_labeled != 'deprecated': warnings.warn('"num_labeled" parameter is not used.' ' It has been deprecated in version 0.5.0 and will be' - 'removed in 0.6.0', DeprecationWarning) + ' removed in 0.6.0', DeprecationWarning) X, y = self._prepare_inputs(X, y, ensure_min_samples=2) num_constraints = self.num_constraints if num_constraints is None: diff --git a/metric_learn/nca.py b/metric_learn/nca.py index 7139f0ff..3545aa89 100644 --- a/metric_learn/nca.py +++ b/metric_learn/nca.py @@ -23,6 +23,7 @@ from sklearn.utils.fixes import logsumexp from sklearn.base import TransformerMixin +from ._util import _check_n_components from .base_metric import MahalanobisMixin EPS = np.finfo(float).eps @@ -36,19 +37,24 @@ class NCA(MahalanobisMixin, TransformerMixin): n_iter_ : `int` The number of iterations the solver has run. - transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + transformer_ : `numpy.ndarray`, shape=(n_components, n_features) The learned linear transformation ``L``. """ - def __init__(self, num_dims=None, max_iter=100, tol=None, verbose=False, - preprocessor=None): + def __init__(self, n_components=None, num_dims='deprecated', max_iter=100, + tol=None, verbose=False, preprocessor=None): """Neighborhood Components Analysis Parameters ---------- - num_dims : int, optional (default=None) - Embedding dimensionality. If None, will be set to ``n_features`` - (``d``) at fit time. + n_components : int or None, optional (default=None) + Dimensionality of reduced space (if None, defaults to dimension of X). + + num_dims : Not used + + .. deprecated:: 0.5.0 + `num_dims` was deprecated in version 0.5.0 and will + be removed in 0.6.0. Use `n_components` instead. max_iter : int, optional (default=100) Maximum number of iterations done by the optimization algorithm. @@ -59,6 +65,7 @@ def __init__(self, num_dims=None, max_iter=100, tol=None, verbose=False, verbose : bool, optional (default=False) Whether to print progress messages or not. """ + self.n_components = n_components self.num_dims = num_dims self.max_iter = max_iter self.tol = tol @@ -70,18 +77,21 @@ def fit(self, X, y): X: data matrix, (n x d) y: scalar labels, (n) """ + if self.num_dims != 'deprecated': + warnings.warn('"num_dims" parameter is not used.' + ' It has been deprecated in version 0.5.0 and will be' + ' removed in 0.6.0. Use "n_components" instead', + DeprecationWarning) X, labels = self._prepare_inputs(X, y, ensure_min_samples=2) n, d = X.shape - num_dims = self.num_dims - if num_dims is None: - num_dims = d + n_components = _check_n_components(d, self.n_components) # Measure the total training time train_time = time.time() # Initialize A to a scaling matrix - A = np.zeros((num_dims, d)) - np.fill_diagonal(A, 1./(np.maximum(X.max(axis=0)-X.min(axis=0), EPS))) + A = np.zeros((n_components, d)) + np.fill_diagonal(A, 1. 
diff --git a/metric_learn/rca.py b/metric_learn/rca.py
index 7d0bb21f..45c9bbf2 100644
--- a/metric_learn/rca.py
+++ b/metric_learn/rca.py
@@ -18,6 +18,7 @@
 from sklearn import decomposition
 from sklearn.base import TransformerMixin

+from ._util import _check_n_components
 from .base_metric import MahalanobisMixin
 from .constraints import Constraints

@@ -42,17 +43,24 @@ class RCA(MahalanobisMixin, TransformerMixin):

   Attributes
   ----------
-  transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
+  transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
     The learned linear transformation ``L``.
   """

-  def __init__(self, num_dims=None, pca_comps=None, preprocessor=None):
+  def __init__(self, n_components=None, num_dims='deprecated',
+               pca_comps=None, preprocessor=None):
     """Initialize the learner.

     Parameters
     ----------
-    num_dims : int, optional
-      embedding dimension (default: original dimension of data)
+    n_components : int or None, optional (default=None)
+      Dimensionality of reduced space (if None, defaults to dimension of X).
+
+    num_dims : Not used
+
+      .. deprecated:: 0.5.0
+        `num_dims` was deprecated in version 0.5.0 and will
+        be removed in 0.6.0. Use `n_components` instead.

     pca_comps : int, float, None or string
       Number of components to keep during PCA preprocessing.
@@ -65,6 +73,7 @@ def __init__(self, num_dims=None, pca_comps=None, preprocessor=None):
       The preprocessor to call to get tuples from indices. If array-like,
       tuples will be formed like this: X[indices].
     """
+    self.n_components = n_components
     self.num_dims = num_dims
     self.pca_comps = pca_comps
     super(RCA, self).__init__(preprocessor)
@@ -77,16 +86,7 @@ def _check_dimension(self, rank, X):
                     'You should adjust pca_comps to remove noise and '
                     'redundant information.')

-    if self.num_dims is None:
-      dim = d
-    elif self.num_dims <= 0:
-      raise ValueError('Invalid embedding dimension: must be greater than 0.')
-    elif self.num_dims > d:
-      dim = d
-      warnings.warn('num_dims (%d) must be smaller than '
-                    'the data dimension (%d)' % (self.num_dims, d))
-    else:
-      dim = self.num_dims
+    dim = _check_n_components(d, self.n_components)
     return dim

   def fit(self, X, chunks):
@@ -100,6 +100,11 @@ def fit(self, X, chunks):
       When ``chunks[i] == -1``, point i doesn't belong to any chunklet.
       When ``chunks[i] == j``, point i belongs to chunklet j.
     """
+    if self.num_dims != 'deprecated':
+      warnings.warn('"num_dims" parameter is not used.'
+                    ' It has been deprecated in version 0.5.0 and will be'
+                    ' removed in 0.6.0. Use "n_components" instead',
+                    DeprecationWarning)
     X, chunks = self._prepare_inputs(X, chunks, ensure_min_samples=2)

     # PCA projection to remove noise and redundant information.
@@ -145,12 +150,13 @@ class RCA_Supervised(RCA):

   Attributes
   ----------
-  transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
+  transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
     The learned linear transformation ``L``.
   """

-  def __init__(self, num_dims=None, pca_comps=None, num_chunks=100,
-               chunk_size=2, preprocessor=None):
+  def __init__(self, num_dims='deprecated', n_components=None,
+               pca_comps=None, num_chunks=100, chunk_size=2,
+               preprocessor=None):
     """Initialize the supervised version of `RCA`.

     `RCA_Supervised` creates chunks of similar points by first sampling a
@@ -159,16 +165,23 @@ def __init__(self, num_dims=None, pca_comps=None, num_chunks=100,

     Parameters
     ----------
-    num_dims : int, optional
-      embedding dimension (default: original dimension of data)
+    n_components : int or None, optional (default=None)
+      Dimensionality of reduced space (if None, defaults to dimension of X).
+
+    num_dims : Not used
+
+      .. deprecated:: 0.5.0
+        `num_dims` was deprecated in version 0.5.0 and will
+        be removed in 0.6.0. Use `n_components` instead.
+
     num_chunks: int, optional

     chunk_size: int, optional

     preprocessor : array-like, shape=(n_samples, n_features) or callable
       The preprocessor to call to get tuples from indices. If array-like,
       tuples will be formed like this: X[indices].
     """
-    RCA.__init__(self, num_dims=num_dims, pca_comps=pca_comps,
-                 preprocessor=preprocessor)
+    RCA.__init__(self, num_dims=num_dims, n_components=n_components,
+                 pca_comps=pca_comps, preprocessor=preprocessor)
     self.num_chunks = num_chunks
     self.chunk_size = chunk_size
diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py
index b300b9ac..73eeefb7 100644
--- a/metric_learn/sdml.py
+++ b/metric_learn/sdml.py
@@ -139,7 +139,7 @@ class SDML(_BaseSDML, _PairsClassifierMixin):

   Attributes
   ----------
-  transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
+  transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
     metric (See function `transformer_from_metric`.)

@@ -187,7 +187,7 @@ class SDML_Supervised(_BaseSDML, TransformerMixin):

   Attributes
   ----------
-  transformer_ : `numpy.ndarray`, shape=(num_dims, n_features)
+  transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
     metric (See function `transformer_from_metric`.)
   """
@@ -247,7 +247,7 @@ def fit(self, X, y, random_state=np.random):
     if self.num_labeled != 'deprecated':
       warnings.warn('"num_labeled" parameter is not used.'
                     ' It has been deprecated in version 0.5.0 and will be'
-                    'removed in 0.6.0', DeprecationWarning)
+                    ' removed in 0.6.0', DeprecationWarning)
     X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
     num_constraints = self.num_constraints
     if num_constraints is None:
""" - def __init__(self, num_dims=None, pca_comps=None, num_chunks=100, - chunk_size=2, preprocessor=None): + def __init__(self, num_dims='deprecated', n_components=None, + pca_comps=None, num_chunks=100, chunk_size=2, + preprocessor=None): """Initialize the supervised version of `RCA`. `RCA_Supervised` creates chunks of similar points by first sampling a @@ -159,16 +165,23 @@ def __init__(self, num_dims=None, pca_comps=None, num_chunks=100, Parameters ---------- - num_dims : int, optional - embedding dimension (default: original dimension of data) + n_components : int or None, optional (default=None) + Dimensionality of reduced space (if None, defaults to dimension of X). + + num_dims : Not used + + .. deprecated:: 0.5.0 + `num_dims` was deprecated in version 0.5.0 and will + be removed in 0.6.0. Use `n_components` instead. + num_chunks: int, optional chunk_size: int, optional preprocessor : array-like, shape=(n_samples, n_features) or callable The preprocessor to call to get tuples from indices. If array-like, tuples will be formed like this: X[indices]. """ - RCA.__init__(self, num_dims=num_dims, pca_comps=pca_comps, - preprocessor=preprocessor) + RCA.__init__(self, num_dims=num_dims, n_components=n_components, + pca_comps=pca_comps, preprocessor=preprocessor) self.num_chunks = num_chunks self.chunk_size = chunk_size diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index b300b9ac..73eeefb7 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -139,7 +139,7 @@ class SDML(_BaseSDML, _PairsClassifierMixin): Attributes ---------- - transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + transformer_ : `numpy.ndarray`, shape=(n_components, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) @@ -187,7 +187,7 @@ class SDML_Supervised(_BaseSDML, TransformerMixin): Attributes ---------- - transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) + transformer_ : `numpy.ndarray`, shape=(n_components, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) """ @@ -247,7 +247,7 @@ def fit(self, X, y, random_state=np.random): if self.num_labeled != 'deprecated': warnings.warn('"num_labeled" parameter is not used.' ' It has been deprecated in version 0.5.0 and will be' - 'removed in 0.6.0', DeprecationWarning) + ' removed in 0.6.0', DeprecationWarning) X, y = self._prepare_inputs(X, y, ensure_min_samples=2) num_constraints = self.num_constraints if num_constraints is None: diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index 06da087a..06c6a288 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -16,7 +16,7 @@ HAS_SKGGM = False else: HAS_SKGGM = True -from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC, +from metric_learn import (LMNN, NCA, LFDA, Covariance, MLKR, MMC, RCA, LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised, MMC_Supervised, SDML, ITML) # Import this specially for testing. @@ -71,7 +71,7 @@ def test_deprecation_num_labeled(self): lsml_supervised = LSML_Supervised(num_labeled=np.inf) msg = ('"num_labeled" parameter is not used.' 
diff --git a/test/test_base_metric.py b/test/test_base_metric.py
index e5f2e17b..7706b1e4 100644
--- a/test/test_base_metric.py
+++ b/test/test_base_metric.py
@@ -22,20 +22,22 @@ def test_lmnn(self):
     self.assertRegexpMatches(
         str(metric_learn.LMNN()),
         r"(python_)?LMNN\(convergence_tol=0.001, k=3, learn_rate=1e-07, "
-        r"max_iter=1000,\s+min_iter=50, preprocessor=None, "
-        r"regularization=0.5, use_pca=True,\s+verbose=False\)")
+        r"max_iter=1000,\s+min_iter=50, n_components=None, "
+        r"num_dims='deprecated',\s+preprocessor=None, "
+        r"regularization=0.5, use_pca=True, verbose=False\)")

   def test_nca(self):
     self.assertEqual(remove_spaces(str(metric_learn.NCA())),
                      remove_spaces(
-                         "NCA(max_iter=100, num_dims=None, preprocessor=None, "
+                         "NCA(max_iter=100, n_components=None, "
+                         "num_dims='deprecated', preprocessor=None, "
                          "tol=None, verbose=False)"))

   def test_lfda(self):
     self.assertEqual(remove_spaces(str(metric_learn.LFDA())),
                      remove_spaces(
                          "LFDA(embedding_type='weighted', k=None, "
-                         "num_dims=None, "
+                         "n_components=None, num_dims='deprecated', "
                          "preprocessor=None)"))

   def test_itml(self):
@@ -79,19 +81,23 @@ def test_sdml(self):

   def test_rca(self):
     self.assertEqual(remove_spaces(str(metric_learn.RCA())),
-                     remove_spaces("RCA(num_dims=None, pca_comps=None, "
+                     remove_spaces("RCA(n_components=None, "
+                                   "num_dims='deprecated', "
+                                   "pca_comps=None, "
                                    "preprocessor=None)"))
     self.assertEqual(remove_spaces(str(metric_learn.RCA_Supervised())),
                      remove_spaces(
-                         "RCA_Supervised(chunk_size=2, num_chunks=100, "
-                         "num_dims=None, pca_comps=None,\n                "
-                         "preprocessor=None)"))
+                         "RCA_Supervised(chunk_size=2, "
+                         "n_components=None, num_chunks=100, "
+                         "num_dims='deprecated', pca_comps=None, "
+                         "preprocessor=None)"))

   def test_mlkr(self):
     self.assertEqual(remove_spaces(str(metric_learn.MLKR())),
                      remove_spaces(
-                         "MLKR(A0=None, max_iter=1000, num_dims=None, "
-                         "preprocessor=None, tol=None,\n     verbose=False)"))
+                         "MLKR(A0=None, max_iter=1000, n_components=None, "
+                         "num_dims='deprecated', "
+                         "preprocessor=None, tol=None, verbose=False)"))

   def test_mmc(self):
     self.assertEqual(remove_spaces(str(metric_learn.MMC())),
@@ -183,5 +189,42 @@ def test_get_metric_works_does_not_raise(estimator, build_dataset):
   assert len(record) == 0


+@pytest.mark.parametrize('estimator, build_dataset', metric_learners,
+                         ids=ids_metric_learners)
+def test_n_components(estimator, build_dataset):
+  """Check that estimators that have an n_components parameter can use it
+  and that it actually works as expected"""
+  input_data, labels, _, X = build_dataset()
+  model = clone(estimator)
+
+  if hasattr(model, 'n_components'):
+    set_random_state(model)
+    model.set_params(n_components=None)
+    model.fit(input_data, labels)
+    assert model.transformer_.shape == (X.shape[1], X.shape[1])
+
+    model = clone(estimator)
+    set_random_state(model)
+    model.set_params(n_components=X.shape[1] - 1)
+    model.fit(input_data, labels)
+    assert model.transformer_.shape == (X.shape[1] - 1, X.shape[1])
+
+    model = clone(estimator)
+    set_random_state(model)
+    model.set_params(n_components=X.shape[1] + 1)
+    with pytest.raises(ValueError) as expected_err:
+      model.fit(input_data, labels)
+    assert (str(expected_err.value) ==
+            'Invalid n_components, must be in [1, {}]'.format(X.shape[1]))
+
+    model = clone(estimator)
+    set_random_state(model)
+    model.set_params(n_components=0)
+    with pytest.raises(ValueError) as expected_err:
+      model.fit(input_data, labels)
+    assert (str(expected_err.value) ==
+            'Invalid n_components, must be in [1, {}]'.format(X.shape[1]))
+
+
 if __name__ == '__main__':
   unittest.main()
Use "n_components" instead') + with pytest.warns(DeprecationWarning) as raised_warning: + nca.fit(X, y) + assert (str(raised_warning[0].message) == msg) + + class TestLFDA(MetricTestCase): def test_iris(self): - lfda = LFDA(k=2, num_dims=2) + lfda = LFDA(k=2, n_components=2) lfda.fit(self.iris_points, self.iris_labels) csep = class_separation(lfda.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.15) @@ -638,9 +654,25 @@ def test_iris(self): self.assertEqual(lfda.transformer_.shape, (2, 4)) +@pytest.mark.parametrize('num_dims', [None, 2]) +def test_deprecation_num_dims_lfda(num_dims): + # test that a deprecation message is thrown if num_labeled is set at + # initialization + # TODO: remove in v.0.6 + X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]]) + y = np.array([1, 0, 1, 0]) + lfda = LFDA(num_dims=num_dims) + msg = ('"num_dims" parameter is not used.' + ' It has been deprecated in version 0.5.0 and will be' + ' removed in 0.6.0. Use "n_components" instead') + with pytest.warns(DeprecationWarning) as raised_warning: + lfda.fit(X, y) + assert (str(raised_warning[0].message) == msg) + + class TestRCA(MetricTestCase): def test_iris(self): - rca = RCA_Supervised(num_dims=2, num_chunks=30, chunk_size=2) + rca = RCA_Supervised(n_components=2, num_chunks=30, chunk_size=2) rca.fit(self.iris_points, self.iris_labels) csep = class_separation(rca.transform(self.iris_points), self.iris_labels) self.assertLess(csep, 0.25) @@ -649,19 +681,44 @@ def test_feature_null_variance(self): X = np.hstack((self.iris_points, np.eye(len(self.iris_points), M=1))) # Apply PCA with the number of components - rca = RCA_Supervised(num_dims=2, pca_comps=3, num_chunks=30, chunk_size=2) + rca = RCA_Supervised(n_components=2, pca_comps=3, num_chunks=30, + chunk_size=2) rca.fit(X, self.iris_labels) csep = class_separation(rca.transform(X), self.iris_labels) self.assertLess(csep, 0.30) # Apply PCA with the minimum variance ratio - rca = RCA_Supervised(num_dims=2, pca_comps=0.95, num_chunks=30, + rca = RCA_Supervised(n_components=2, pca_comps=0.95, num_chunks=30, chunk_size=2) rca.fit(X, self.iris_labels) csep = class_separation(rca.transform(X), self.iris_labels) self.assertLess(csep, 0.30) +@pytest.mark.parametrize('num_dims', [None, 2]) +def test_deprecation_num_dims_rca(num_dims): + # test that a deprecation message is thrown if num_labeled is set at + # initialization + # TODO: remove in v.0.6 + X, y = load_iris(return_X_y=True) + rca = RCA(num_dims=num_dims) + msg = ('"num_dims" parameter is not used.' + ' It has been deprecated in version 0.5.0 and will be' + ' removed in 0.6.0. Use "n_components" instead') + with pytest.warns(DeprecationWarning) as raised_warning: + rca.fit(X, y) + assert (str(raised_warning[0].message) == msg) + + # we take a small number of chunks so that RCA works on iris + rca_supervised = RCA_Supervised(num_dims=num_dims, num_chunks=10) + msg = ('"num_dims" parameter is not used.' + ' It has been deprecated in version 0.5.0 and will be' + ' removed in 0.6.0. Use "n_components" instead') + with pytest.warns(DeprecationWarning) as raised_warning: + rca_supervised.fit(X, y) + assert (str(raised_warning[0].message) == msg) + + class TestMLKR(MetricTestCase): def test_iris(self): mlkr = MLKR() @@ -693,6 +750,22 @@ def grad_fn(M): np.testing.assert_almost_equal(rel_diff, 0.) 
diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py
index 15bf1aed..e7fa5b17 100644
--- a/test/test_mahalanobis_mixin.py
+++ b/test/test_mahalanobis_mixin.py
@@ -137,11 +137,8 @@ def test_embed_dim(estimator, build_dataset):
     model.score_pairs(model.transform(X[0, :]))
   assert str(raised_error.value) == err_msg
   # we test that the shape is also OK when doing dimensionality reduction
-  if type(model).__name__ in {'LFDA', 'MLKR', 'NCA', 'RCA'}:
-    # TODO:
-    # avoid this enumeration and rather test if hasattr n_components
-    # as soon as we have made the arguments names as such (issue #167)
-    model.set_params(num_dims=2)
+  if hasattr(model, 'n_components'):
+    model.set_params(n_components=2)
     model.fit(*remove_y_quadruplets(estimator, input_data, labels))
     assert model.transform(X).shape == (X.shape[0], 2)
     # assert that ValueError is thrown if input shape is 1D
diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py
index 091c56e2..6b451aee 100644
--- a/test/test_sklearn_compat.py
+++ b/test/test_sklearn_compat.py
@@ -352,8 +352,8 @@ def test_dict_unchanged(estimator, build_dataset, with_preprocessor):
    to_transform) = build_dataset(with_preprocessor)
   estimator = clone(estimator)
   estimator.set_params(preprocessor=preprocessor)
-  if hasattr(estimator, "num_dims"):
-    estimator.num_dims = 1
+  if hasattr(estimator, "n_components"):
+    estimator.n_components = 1
   estimator.fit(*remove_y_quadruplets(estimator, input_data, labels))

   def check_dict():
@@ -381,8 +381,8 @@ def test_dont_overwrite_parameters(estimator, build_dataset,
   input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
   estimator = clone(estimator)
   estimator.set_params(preprocessor=preprocessor)
-  if hasattr(estimator, "num_dims"):
-    estimator.num_dims = 1
+  if hasattr(estimator, "n_components"):
+    estimator.n_components = 1
   dict_before_fit = estimator.__dict__.copy()
   estimator.fit(*remove_y_quadruplets(estimator, input_data, labels))
diff --git a/test/test_transformer_metric_conversion.py b/test/test_transformer_metric_conversion.py
index 4328320d..0139f632 100644
--- a/test/test_transformer_metric_conversion.py
+++ b/test/test_transformer_metric_conversion.py
@@ -63,20 +63,20 @@ def test_nca(self):
     assert_array_almost_equal(L.T.dot(L), nca.get_mahalanobis_matrix())

   def test_lfda(self):
-    lfda = LFDA(k=2, num_dims=2)
+    lfda = LFDA(k=2, n_components=2)
     lfda.fit(self.X, self.y)
     L = lfda.transformer_
     assert_array_almost_equal(L.T.dot(L), lfda.get_mahalanobis_matrix())

   def test_rca_supervised(self):
     seed = np.random.RandomState(1234)
-    rca = RCA_Supervised(num_dims=2, num_chunks=30, chunk_size=2)
+    rca = RCA_Supervised(n_components=2, num_chunks=30, chunk_size=2)
     rca.fit(self.X, self.y, random_state=seed)
     L = rca.transformer_
     assert_array_almost_equal(L.T.dot(L), rca.get_mahalanobis_matrix())

   def test_mlkr(self):
-    mlkr = MLKR(num_dims=2)
+    mlkr = MLKR(n_components=2)
     mlkr.fit(self.X, self.y)
     L = mlkr.transformer_
     assert_array_almost_equal(L.T.dot(L), mlkr.get_mahalanobis_matrix())
diff --git a/test/test_utils.py b/test/test_utils.py
index 6441fac6..08415a76 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -10,7 +10,7 @@
 from metric_learn._util import (check_input, make_context, preprocess_tuples,
                                 make_name, preprocess_points,
                                 check_collapsed_pairs, validate_vector,
-                                _check_sdp_from_eigen,
+                                _check_sdp_from_eigen, _check_n_components,
                                 check_y_valid_values_for_pairs)
 from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA,
                           LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised,
@@ -867,9 +867,9 @@ def test_same_with_or_without_preprocessor(estimator, build_dataset):
   formed_points_to_transform = dataset_formed.to_transform
   (indices_train, indices_test, y_train,
    y_test, formed_train, formed_test) = train_test_split(dataset_indices.data,
-                                                     dataset_indices.target,
-                                                     dataset_formed.data,
-                                                     random_state=SEED)
+                                                         dataset_indices.target,
+                                                         dataset_formed.data,
+                                                         random_state=SEED)

   def make_random_state(estimator):
     rs = {}
@@ -1008,6 +1008,24 @@ def test_check_sdp_from_eigen_positive_err_messages():
     _check_sdp_from_eigen(w, None)


+def test__check_n_components():
+  """Checks that `_check_n_components` returns what is expected
+  (including the errors)"""
+  dim = _check_n_components(5, None)
+  assert dim == 5
+
+  dim = _check_n_components(5, 3)
+  assert dim == 3
+
+  with pytest.raises(ValueError) as expected_err:
+    _check_n_components(5, 10)
+  assert str(expected_err.value) == 'Invalid n_components, must be in [1, 5]'
+
+  with pytest.raises(ValueError) as expected_err:
+    _check_n_components(5, 0)
+  assert str(expected_err.value) == 'Invalid n_components, must be in [1, 5]'
+
+
 @pytest.mark.unit
 @pytest.mark.parametrize('wrong_labels',
                          [[0.5, 0.6, 0.7, 0.8, 0.9],