Skip to content

Commit 6630d0a

Browse files
author
William de Vazelhes
committed
FIX make some tests work
1 parent 047191b commit 6630d0a

File tree

2 files changed

+75
-71
lines changed

2 files changed

+75
-71
lines changed

metric_learn/_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,5 +413,5 @@ def _check_n_components(n_features, n_components):
413413
if n_components is None:
414414
return n_features
415415
if 0 < n_components <= n_features:
416-
return n_features
416+
return n_components
417417
raise ValueError('Invalid n_components, must be in [1, %d]' % n_features)

test/metric_learn_test.py

Lines changed: 74 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -494,21 +494,22 @@ def test_one_class(self):
494494
nca.fit(X, y)
495495
assert_array_equal(nca.transformer_, A)
496496

497-
@pytest.mark.parametrize('num_dims', [None, 2])
498-
def test_deprecation_num_dims(self, num_dims):
499-
# test that a deprecation message is thrown if num_labeled is set at
500-
# initialization
501-
# TODO: remove in v.0.6
502-
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
503-
y = np.array([1, 0, 1, 0])
504-
nca = NCA(num_dims=num_dims)
505-
msg = ('"num_dims" parameter is not used.'
506-
' It has been deprecated in version 0.5.0 and will be'
507-
'removed in 0.6.0. Use "n_components" instead',
508-
DeprecationWarning)
509-
with pytest.warns(DeprecationWarning) as raised_warning:
510-
nca.fit(X, y)
511-
assert (str(raised_warning[0].message) == msg)
497+
498+
@pytest.mark.parametrize('num_dims', [None, 2])
499+
def test_deprecation_num_dims_nca(num_dims):
500+
# test that a deprecation message is thrown if num_labeled is set at
501+
# initialization
502+
# TODO: remove in v.0.6
503+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
504+
y = np.array([1, 0, 1, 0])
505+
nca = NCA(num_dims=num_dims)
506+
msg = ('"num_dims" parameter is not used.'
507+
' It has been deprecated in version 0.5.0 and will be'
508+
'removed in 0.6.0. Use "n_components" instead',
509+
DeprecationWarning)
510+
with pytest.warns(DeprecationWarning) as raised_warning:
511+
nca.fit(X, y)
512+
assert (str(raised_warning[0].message) == msg)
512513

513514

514515
class TestLFDA(MetricTestCase):
@@ -522,21 +523,22 @@ def test_iris(self):
522523
self.assertEqual(lfda.get_mahalanobis_matrix().shape, (4, 4))
523524
self.assertEqual(lfda.transformer_.shape, (2, 4))
524525

525-
@pytest.mark.parametrize('num_dims', [None, 2])
526-
def test_deprecation_num_dims(self, num_dims):
527-
# test that a deprecation message is thrown if num_labeled is set at
528-
# initialization
529-
# TODO: remove in v.0.6
530-
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
531-
y = np.array([1, 0, 1, 0])
532-
lfda = LFDA(num_dims=num_dims)
533-
msg = ('"num_dims" parameter is not used.'
534-
' It has been deprecated in version 0.5.0 and will be'
535-
'removed in 0.6.0. Use "n_components" instead',
536-
DeprecationWarning)
537-
with pytest.warns(DeprecationWarning) as raised_warning:
538-
lfda.fit(X, y)
539-
assert (str(raised_warning[0].message) == msg)
526+
527+
@pytest.mark.parametrize('num_dims', [None, 2])
528+
def test_deprecation_num_dims_lfda(num_dims):
529+
# test that a deprecation message is thrown if num_labeled is set at
530+
# initialization
531+
# TODO: remove in v.0.6
532+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
533+
y = np.array([1, 0, 1, 0])
534+
lfda = LFDA(num_dims=num_dims)
535+
msg = ('"num_dims" parameter is not used.'
536+
' It has been deprecated in version 0.5.0 and will be'
537+
'removed in 0.6.0. Use "n_components" instead',
538+
DeprecationWarning)
539+
with pytest.warns(DeprecationWarning) as raised_warning:
540+
lfda.fit(X, y)
541+
assert (str(raised_warning[0].message) == msg)
540542

541543

542544
class TestRCA(MetricTestCase):
@@ -546,31 +548,6 @@ def test_iris(self):
546548
csep = class_separation(rca.transform(self.iris_points), self.iris_labels)
547549
self.assertLess(csep, 0.25)
548550

549-
@pytest.mark.parametrize('num_dims', [None, 2])
550-
def test_deprecation_num_dims(self, num_dims):
551-
# test that a deprecation message is thrown if num_labeled is set at
552-
# initialization
553-
# TODO: remove in v.0.6
554-
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
555-
y = np.array([1, 0, 1, 0])
556-
rca = RCA(num_dims=num_dims)
557-
msg = ('"num_dims" parameter is not used.'
558-
' It has been deprecated in version 0.5.0 and will be'
559-
'removed in 0.6.0. Use "n_components" instead',
560-
DeprecationWarning)
561-
with pytest.warns(DeprecationWarning) as raised_warning:
562-
rca.fit(X, y)
563-
assert (str(raised_warning[0].message) == msg)
564-
565-
rca_supervised = RCA_Supervised(num_dims=num_dims)
566-
msg = ('"num_dims" parameter is not used.'
567-
' It has been deprecated in version 0.5.0 and will be'
568-
'removed in 0.6.0. Use "n_components" instead',
569-
DeprecationWarning)
570-
with pytest.warns(DeprecationWarning) as raised_warning:
571-
rca_supervised.fit(X, y)
572-
assert (str(raised_warning[0].message) == msg)
573-
574551
def test_feature_null_variance(self):
575552
X = np.hstack((self.iris_points, np.eye(len(self.iris_points), M=1)))
576553

@@ -589,6 +566,32 @@ def test_feature_null_variance(self):
589566
self.assertLess(csep, 0.30)
590567

591568

569+
@pytest.mark.parametrize('num_dims', [None, 2])
570+
def test_deprecation_num_dims_rca(num_dims):
571+
# test that a deprecation message is thrown if num_labeled is set at
572+
# initialization
573+
# TODO: remove in v.0.6
574+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
575+
y = np.array([1, 0, 1, 0])
576+
rca = RCA(num_dims=num_dims)
577+
msg = ('"num_dims" parameter is not used.'
578+
' It has been deprecated in version 0.5.0 and will be'
579+
'removed in 0.6.0. Use "n_components" instead',
580+
DeprecationWarning)
581+
with pytest.warns(DeprecationWarning) as raised_warning:
582+
rca.fit(X, y)
583+
assert (str(raised_warning[0].message) == msg)
584+
585+
rca_supervised = RCA_Supervised(num_dims=num_dims)
586+
msg = ('"num_dims" parameter is not used.'
587+
' It has been deprecated in version 0.5.0 and will be'
588+
'removed in 0.6.0. Use "n_components" instead',
589+
DeprecationWarning)
590+
with pytest.warns(DeprecationWarning) as raised_warning:
591+
rca_supervised.fit(X, y)
592+
assert (str(raised_warning[0].message) == msg)
593+
594+
592595
class TestMLKR(MetricTestCase):
593596
def test_iris(self):
594597
mlkr = MLKR()
@@ -619,21 +622,22 @@ def grad_fn(M):
619622
rel_diff = check_grad(fun, grad_fn, M.ravel()) / np.linalg.norm(grad_fn(M))
620623
np.testing.assert_almost_equal(rel_diff, 0.)
621624

622-
@pytest.mark.parametrize('num_dims', [None, 2])
623-
def test_deprecation_num_dims(self, num_dims):
624-
# test that a deprecation message is thrown if num_labeled is set at
625-
# initialization
626-
# TODO: remove in v.0.6
627-
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
628-
y = np.array([1, 0, 1, 0])
629-
mlkr = MLKR(num_dims=num_dims)
630-
msg = ('"num_dims" parameter is not used.'
631-
' It has been deprecated in version 0.5.0 and will be'
632-
'removed in 0.6.0. Use "n_components" instead',
633-
DeprecationWarning)
634-
with pytest.warns(DeprecationWarning) as raised_warning:
635-
mlkr.fit(X, y)
636-
assert (str(raised_warning[0].message) == msg)
625+
626+
@pytest.mark.parametrize('num_dims', [None, 2])
627+
def test_deprecation_num_dims_mlkr(num_dims):
628+
# test that a deprecation message is thrown if num_labeled is set at
629+
# initialization
630+
# TODO: remove in v.0.6
631+
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
632+
y = np.array([1, 0, 1, 0])
633+
mlkr = MLKR(num_dims=num_dims)
634+
msg = ('"num_dims" parameter is not used.'
635+
' It has been deprecated in version 0.5.0 and will be'
636+
'removed in 0.6.0. Use "n_components" instead',
637+
DeprecationWarning)
638+
with pytest.warns(DeprecationWarning) as raised_warning:
639+
mlkr.fit(X, y)
640+
assert (str(raised_warning[0].message) == msg)
637641

638642

639643
class TestMMC(MetricTestCase):

0 commit comments

Comments
 (0)