Commit 130cbad

wdevazelhes authored and perimosocordiae committed
[MRG] Uniformize initialization for all algorithms (#195)
* initiate PR
* Revert "initiate PR". This reverts commit a2ae9e1.
* FEAT: uniformize init for NCA and RCA
* Let the check of num_dims be done in the other PR
* Add metric initialization for algorithms that learn a mahalanobis matrix
* Add initialization for MLKR
* FIX: fix error message for dimension
* FIX: fix StringRepr for MLKR
* FIX: fix tests by reshaping to the right dataset size
* Remove lda in docstring of MLKR
* MAINT: Add deprecation for previous initializations
* Update tests with new initialization
* Make random init for mahalanobis metric generate an SPD matrix
* Ensure the input mahalanobis metric initialization is symmetric, and say it should be SPD
* various fixes
* MAINT: various refactoring (MLKR: update default test init; SDML: refactor prior_inv)
* FIX: fix default covariance for SDML in tests
* Enhance docstring
* Set random state for SDML
* Fix merge remove_spaces that was forgotten
* Fix indent
* XP: try to change the way we choose n_components to see if it fixes the test
* Revert "XP: try to change the way we choose n_components to see if it fixes the test". This reverts commit e86b61b.
* Be more tolerant in test
* Add test for singular covariance matrix
* Fix test_singular_covariance_init
* DOC: update docstring saying pseudo-inverse
* Revert "Fix test_singular_covariance_init". This reverts commit d2cc7ce.
* Ensure definiteness before returning the inverse
* wip: deal with non-definiteness
* Rename init to prior for SDML and LSML
* Update error messages with either prior or init
* Remove message
* A few nitpicks
* Fix PEP8 errors + change init in test
* STY: PEP8 fixes
* Address and remove TODOs
* Replace init by prior for ITML
* TST: fix ITML test with init changed into prior
* Add precision for MMC
* Add ChangedBehaviorWarning for the algorithms that changed
* Address #195 (review)
* Remove the warnings check since we now have a ChangedBehaviorWarning
* Be more precise: it should not raise any ConvergenceWarning
* Address #195 (review)
* FIX: remaining comment
* TST: update test error message
* Improve readability
* Address #195 (review)
* TST: fix docstring of LMNN
* Fix warning messages
* Fix warning messages that changed
1 parent 3899653 · commit 130cbad
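In practice, the commit gives every learner a uniform way to seed its metric or transformation: estimators that learn a linear transformation take an `init`, while estimators that regularize towards a Mahalanobis matrix take a `prior` (renamed from `init` for SDML, LSML and ITML). A minimal usage sketch, with constructor names and options taken from the diffs and the commit list above (defaults and exact signatures may differ by version):

    import metric_learn

    # transformation learners: `init` in {'auto', 'pca', 'lda', 'identity', 'random'}
    nca = metric_learn.NCA(init='auto', n_components=2)   # 'auto' picks lda/pca/identity
    mlkr = metric_learn.MLKR(init='pca')                  # regression, so no 'lda' option

    # Mahalanobis-matrix learners: `prior` in {'identity', 'covariance', 'random'}
    sdml = metric_learn.SDML_Supervised(prior='identity')
    lsml = metric_learn.LSML_Supervised(prior='covariance')  # (pseudo-)inverse covariance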

18 files changed (+1626, -223 lines)

bench/benchmarks/iris.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
     'LMNN': metric_learn.LMNN(k=5, learn_rate=1e-6, verbose=False),
     'LSML_Supervised': metric_learn.LSML_Supervised(num_constraints=200),
     'MLKR': metric_learn.MLKR(),
-    'NCA': metric_learn.NCA(max_iter=700, num_dims=2),
+    'NCA': metric_learn.NCA(max_iter=700, n_components=2),
     'RCA_Supervised': metric_learn.RCA_Supervised(dim=2, num_chunks=30,
                                                   chunk_size=2),
     'SDML_Supervised': metric_learn.SDML_Supervised(num_constraints=1500),

metric_learn/_util.py

Lines changed: 328 additions & 7 deletions
@@ -1,10 +1,16 @@
-import warnings
 import numpy as np
+import scipy
 import six
 from numpy.linalg import LinAlgError
+from sklearn.datasets import make_spd_matrix
+from sklearn.decomposition import PCA
 from sklearn.utils import check_array
-from sklearn.utils.validation import check_X_y
-from metric_learn.exceptions import PreprocessorError
+from sklearn.utils.validation import check_X_y, check_random_state
+from .exceptions import PreprocessorError, NonPSDError
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+from scipy.linalg import pinvh
+import sys
+import time

 # hack around lack of axis kwarg in older numpy versions
 try:
@@ -335,28 +341,38 @@ def check_collapsed_pairs(pairs):
 def _check_sdp_from_eigen(w, tol=None):
   """Checks if some of the eigenvalues given are negative, up to a tolerance
   level, with a default value of the tolerance depending on the eigenvalues.
+  It also returns whether the matrix is positive definite, up to the above
+  tolerance.

   Parameters
   ----------
   w : array-like, shape=(n_eigenvalues,)
     Eigenvalues to check for non semidefinite positiveness.

   tol : positive `float`, optional
-    Negative eigenvalues above - tol are considered zero. If
+    Absolute eigenvalues below tol are considered zero. If
     tol is None, and eps is the epsilon value for datatype of w, then tol
-    is set to w.max() * len(w) * eps.
+    is set to abs(w).max() * len(w) * eps.
+
+  Returns
+  -------
+  is_definite : bool
+    Whether the matrix is positive definite or not.

   See Also
   --------
   np.linalg.matrix_rank for more details on the choice of tolerance (the same
   strategy is applied here)
   """
   if tol is None:
-    tol = w.max() * len(w) * np.finfo(w.dtype).eps
+    tol = np.abs(w).max() * len(w) * np.finfo(w.dtype).eps
   if tol < 0:
     raise ValueError("tol should be positive.")
   if any(w < - tol):
-    raise ValueError("Matrix is not positive semidefinite (PSD).")
+    raise NonPSDError()
+  if any(abs(w) < tol):
+    return False
+  return True


 def transformer_from_metric(metric, tol=None):
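The reworked check above no longer only raises on negative eigenvalues; it also reports whether the matrix is strictly definite, which the new initializers use to decide when a true inverse exists. A standalone sketch of the same tolerance logic (the helper and matrices here are illustrative, not part of the library):

    import numpy as np

    def check_psd_and_definite(M, tol=None):
        # mirrors _check_sdp_from_eigen: scaled tolerance on the eigenvalues
        w = np.linalg.eigvalsh(M)
        if tol is None:
            tol = np.abs(w).max() * len(w) * np.finfo(w.dtype).eps
        if np.any(w < -tol):
            raise ValueError("not PSD")       # the library raises NonPSDError here
        return not np.any(np.abs(w) < tol)    # False when PSD but singular

    assert check_psd_and_definite(np.eye(3))                  # strictly PD
    assert not check_psd_and_definite(np.diag([1., 1., 0.]))  # PSD, not definite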
@@ -413,6 +429,311 @@ def validate_vector(u, dtype=None):
   return u


+def _initialize_transformer(n_components, input, y=None, init='auto',
+                            verbose=False, random_state=None,
+                            has_classes=True):
+  """Returns the initial transformer to be used depending on the arguments.
+
+  Parameters
+  ----------
+  n_components : int
+    The number of components to take. (Note: it should have been checked
+    before, meaning it should not be None and it should be a value in
+    [1, X.shape[1]])
+
+  input : array-like
+    The input samples (can be tuples or regular samples).
+
+  y : array-like or None
+    The input labels (or not if there are no labels).
+
+  init : string or numpy array, optional (default='auto')
+    Initialization of the linear transformation. Possible options are
+    'auto', 'pca', 'lda', 'identity', 'random', and a numpy array of shape
+    (n_features_a, n_features_b).
+
+    'auto'
+      Depending on ``n_components``, the most reasonable initialization
+      will be chosen. If ``n_components <= n_classes`` we use 'lda' (see
+      the description of 'lda' init), as it uses labels information. If
+      not, but ``n_components < min(n_features, n_samples)``, we use 'pca',
+      as it projects data onto meaningful directions (those of higher
+      variance). Otherwise, we just use 'identity'.
+
+    'pca'
+      ``n_components`` principal components of the inputs passed
+      to :meth:`fit` will be used to initialize the transformation.
+      (See `sklearn.decomposition.PCA`)
+
+    'lda'
+      ``min(n_components, n_classes)`` most discriminative
+      components of the inputs passed to :meth:`fit` will be used to
+      initialize the transformation. (If ``n_components > n_classes``,
+      the rest of the components will be zero.) (See
+      `sklearn.discriminant_analysis.LinearDiscriminantAnalysis`).
+      This initialization is possible only if `has_classes == True`.
+
+    'identity'
+      The identity matrix. If ``n_components`` is strictly smaller than the
+      dimensionality of the inputs passed to :meth:`fit`, the identity
+      matrix will be truncated to the first ``n_components`` rows.
+
+    'random'
+      The initial transformation will be a random array of shape
+      `(n_components, n_features)`. Each value is sampled from the
+      standard normal distribution.
+
+    numpy array
+      n_features_b must match the dimensionality of the inputs passed to
+      :meth:`fit` and n_features_a must be less than or equal to that.
+      If ``n_components`` is not None, n_features_a must match it.
+
+  verbose : bool
+    Whether to print the details of the initialization or not.
+
+  random_state : int or `numpy.RandomState` or None, optional (default=None)
+    A pseudo random number generator object or a seed for it if int. If
+    ``init='random'``, ``random_state`` is used to initialize the random
+    transformation. If ``init='pca'``, ``random_state`` is passed as an
+    argument to PCA when initializing the transformation.
+
+  has_classes : bool (default=True)
+    Whether the labels are in fact classes. If true, this will allow to use
+    the 'lda' initialization.
+
+  Returns
+  -------
+  init_transformer : `numpy.ndarray`
+    The initial transformer to use.
+  """
+  # if we are doing a regression we cannot use lda:
+  n_features = input.shape[-1]
+  authorized_inits = ['auto', 'pca', 'identity', 'random']
+  if has_classes:
+    authorized_inits.append('lda')
+
+  if isinstance(init, np.ndarray):
+    # we copy the array, so that if we update the metric, we don't want to
+    # update the init
+    init = check_array(init, copy=True)
+
+    # Assert that init.shape[1] = X.shape[1]
+    if init.shape[1] != n_features:
+      raise ValueError('The input dimensionality ({}) of the given '
+                       'linear transformation `init` must match the '
+                       'dimensionality of the given inputs `X` ({}).'
+                       .format(init.shape[1], n_features))
+
+    # Assert that init.shape[0] <= init.shape[1]
+    if init.shape[0] > init.shape[1]:
+      raise ValueError('The output dimensionality ({}) of the given '
+                       'linear transformation `init` cannot be '
+                       'greater than its input dimensionality ({}).'
+                       .format(init.shape[0], init.shape[1]))
+
+    # Assert that self.n_components = init.shape[0]
+    if n_components != init.shape[0]:
+      raise ValueError('The preferred dimensionality of the '
+                       'projected space `n_components` ({}) does'
+                       ' not match the output dimensionality of '
+                       'the given linear transformation '
+                       '`init` ({})!'
+                       .format(n_components,
+                               init.shape[0]))
+  elif init not in authorized_inits:
+    raise ValueError(
+        "`init` must be '{}' "
+        "or a numpy array of shape (n_components, n_features)."
+        .format("', '".join(authorized_inits)))
+
+  random_state = check_random_state(random_state)
+  if isinstance(init, np.ndarray):
+    return init
+  n_samples = input.shape[0]
+  if init == 'auto':
+    if has_classes:
+      n_classes = len(np.unique(y))
+    else:
+      n_classes = -1
+    init = _auto_select_init(has_classes, n_features, n_samples, n_components,
+                             n_classes)
+  if init == 'identity':
+    return np.eye(n_components, input.shape[-1])
+  elif init == 'random':
+    return random_state.randn(n_components, input.shape[-1])
+  elif init in {'pca', 'lda'}:
+    init_time = time.time()
+    if init == 'pca':
+      pca = PCA(n_components=n_components,
+                random_state=random_state)
+      if verbose:
+        print('Finding principal components... ')
+        sys.stdout.flush()
+      pca.fit(input)
+      transformation = pca.components_
+    elif init == 'lda':
+      lda = LinearDiscriminantAnalysis(n_components=n_components)
+      if verbose:
+        print('Finding most discriminative components... ')
+        sys.stdout.flush()
+      lda.fit(input, y)
+      transformation = lda.scalings_.T[:n_components]
+    if verbose:
+      print('done in {:5.2f}s'.format(time.time() - init_time))
+    return transformation
+
+
+def _auto_select_init(has_classes, n_features, n_samples, n_components,
+                      n_classes):
+  if has_classes and n_components <= min(n_features, n_classes - 1):
+    init = 'lda'
+  elif n_components < min(n_features, n_samples):
+    init = 'pca'
+  else:
+    init = 'identity'
+  return init
+
+
+def _initialize_metric_mahalanobis(input, init='identity', random_state=None,
+                                   return_inverse=False, strict_pd=False,
+                                   matrix_name='matrix'):
+  """Returns a PSD matrix that can be used as a prior or an initialization
+  for the Mahalanobis distance
+
+  Parameters
+  ----------
+  input : array-like
+    The input samples (can be tuples or regular samples).
+
+  init : string or numpy array, optional (default='identity')
+    Specification for the matrix to initialize. Possible options are
+    'identity', 'covariance', 'random', and a numpy array of shape
+    (n_features, n_features).
+
+    'identity'
+      An identity matrix of shape (n_features, n_features).
+
+    'covariance'
+      The (pseudo-)inverse covariance matrix (raises an error if the
+      covariance matrix is not definite and `strict_pd == True`)
+
+    'random'
+      A random positive definite (PD) matrix of shape
+      `(n_features, n_features)`, generated using
+      `sklearn.datasets.make_spd_matrix`.
+
+    numpy array
+      A PSD matrix (or strictly PD if strict_pd==True) of
+      shape (n_features, n_features), that will be used as such to
+      initialize the metric, or set the prior.
+
+  random_state : int or `numpy.RandomState` or None, optional (default=None)
+    A pseudo random number generator object or a seed for it if int. If
+    ``init='random'``, ``random_state`` is used to set the random
+    Mahalanobis matrix.
+
+  return_inverse : bool, optional (default=False)
+    Whether to return the inverse of the specified matrix. This
+    can be sometimes useful. It will return the pseudo-inverse (which is the
+    same as the inverse if the matrix is definite (i.e. invertible)). If
+    `strict_pd == True` and the matrix is not definite, it will raise an
+    error.
+
+  strict_pd : bool, optional (default=False)
+    Whether to enforce that the provided matrix is definite (in addition to
+    being PSD).
+
+  matrix_name : str, optional (default='matrix')
+    The name of the matrix used (example: 'init', 'prior'). Will be used in
+    error messages.
+
+  Returns
+  -------
+  M, or (M, M_inv) : `numpy.ndarray`
+    The initial matrix to use M, and its inverse if `return_inverse=True`.
+  """
+  n_features = input.shape[-1]
+  if isinstance(init, np.ndarray):
+    # we copy the array, so that if we update the metric, we don't want to
+    # update the init
+    init = check_array(init, copy=True)
+
+    # Assert that init.shape[1] = n_features
+    if init.shape != (n_features,) * 2:
+      raise ValueError('The input dimensionality {} of the given '
+                       'mahalanobis matrix `{}` must match the '
+                       'dimensionality of the given inputs ({}).'
+                       .format(init.shape, matrix_name, n_features))
+
+    # Assert that the matrix is symmetric
+    if not np.allclose(init, init.T):
+      raise ValueError("`{}` is not symmetric.".format(matrix_name))
+
+  elif init not in ['identity', 'covariance', 'random']:
+    raise ValueError(
+        "`{}` must be 'identity', 'covariance', 'random' "
+        "or a numpy array of shape (n_features, n_features)."
+        .format(matrix_name))
+
+  random_state = check_random_state(random_state)
+  M = init
+  if isinstance(init, np.ndarray):
+    s, u = scipy.linalg.eigh(init)
+    init_is_definite = _check_sdp_from_eigen(s)
+    if strict_pd and not init_is_definite:
+      raise LinAlgError("You should provide a strictly positive definite "
+                        "matrix as `{}`. This one is not definite. Try another"
+                        " {}, or an algorithm that does not "
+                        "require the {} to be strictly positive definite."
+                        .format(*((matrix_name,) * 3)))
+    if return_inverse:
+      M_inv = np.dot(u / s, u.T)
+      return M, M_inv
+    else:
+      return M
+  elif init == 'identity':
+    M = np.eye(n_features, n_features)
+    if return_inverse:
+      M_inv = M.copy()
+      return M, M_inv
+    else:
+      return M
+  elif init == 'covariance':
+    if input.ndim == 3:
+      # if the input are tuples, we need to form an X by deduplication
+      X = np.vstack({tuple(row) for row in input.reshape(-1, n_features)})
+    else:
+      X = input
+    # atleast2d is necessary to deal with scalar covariance matrices
+    M_inv = np.atleast_2d(np.cov(X, rowvar=False))
+    s, u = scipy.linalg.eigh(M_inv)
+    cov_is_definite = _check_sdp_from_eigen(s)
+    if strict_pd and not cov_is_definite:
+      raise LinAlgError("Unable to get a true inverse of the covariance "
+                        "matrix since it is not definite. Try another "
+                        "`{}`, or an algorithm that does not "
+                        "require the `{}` to be strictly positive definite."
+                        .format(*((matrix_name,) * 2)))
+    M = np.dot(u / s, u.T)
+    if return_inverse:
+      return M, M_inv
+    else:
+      return M
+  elif init == 'random':
+    # we need to create a random symmetric matrix
+    M = make_spd_matrix(n_features, random_state=random_state)
+    if return_inverse:
+      # we use pinvh even if we know the matrix is definite, just because
+      # we need the returned matrix to be symmetric (and sometimes
+      # np.linalg.inv returns not symmetric inverses of symmetric matrices)
+      # TODO: there might be a more efficient method to do so
+      M_inv = pinvh(M)
+      return M, M_inv
+    else:
+      return M
+
+
 def _check_n_components(n_features, n_components):
   """Checks that n_components is less than n_features and deal with the None
   case"""

metric_learn/covariance.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ class Covariance(MahalanobisMixin, TransformerMixin):

   Attributes
   ----------
-  transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
+  transformer_ : `numpy.ndarray`, shape=(n_features, n_features)
     The linear transformation ``L`` deduced from the learned Mahalanobis
     metric (See function `transformer_from_metric`.)
   """
