|
1 | 1 | import pytest
|
| 2 | +import re |
2 | 3 | import unittest
|
3 | 4 | import metric_learn
|
4 | 5 | import numpy as np
|
|
7 | 8 | from test.test_utils import ids_metric_learners, metric_learners
|
8 | 9 |
|
9 | 10 |
|
| 11 | +def remove_spaces(s): |
| 12 | + return re.sub('\s+', '', s) |
| 13 | + |
| 14 | + |
10 | 15 | class TestStringRepr(unittest.TestCase):
|
11 | 16 |
|
12 | 17 | def test_covariance(self):
|
13 |
| - self.assertEqual(str(metric_learn.Covariance()), |
14 |
| - "Covariance(preprocessor=None)") |
| 18 | + self.assertEqual(remove_spaces(str(metric_learn.Covariance())), |
| 19 | + remove_spaces("Covariance(preprocessor=None)")) |
15 | 20 |
|
16 | 21 | def test_lmnn(self):
|
17 | 22 | self.assertRegexpMatches(
|
18 |
| - str(metric_learn.LMNN()), |
19 |
| - r"(python_)?LMNN\(convergence_tol=0.001, k=3, learn_rate=1e-07, " |
20 |
| - r"max_iter=1000,\n min_iter=50, preprocessor=None, " |
21 |
| - r"regularization=0.5, use_pca=True,\n verbose=False\)") |
| 23 | + str(metric_learn.LMNN()), |
| 24 | + r"(python_)?LMNN\(convergence_tol=0.001, k=3, learn_rate=1e-07, " |
| 25 | + r"max_iter=1000,\s+min_iter=50, preprocessor=None, " |
| 26 | + r"regularization=0.5, use_pca=True,\s+verbose=False\)") |
22 | 27 |
|
23 | 28 | def test_nca(self):
|
24 |
| - self.assertEqual(str(metric_learn.NCA()), |
25 |
| - "NCA(max_iter=100, num_dims=None, preprocessor=None, " |
26 |
| - "tol=None, verbose=False)") |
| 29 | + self.assertEqual(remove_spaces(str(metric_learn.NCA())), |
| 30 | + remove_spaces( |
| 31 | + "NCA(max_iter=100, num_dims=None, preprocessor=None, " |
| 32 | + "tol=None, verbose=False)")) |
27 | 33 |
|
28 | 34 | def test_lfda(self):
|
29 |
| - self.assertEqual(str(metric_learn.LFDA()), |
30 |
| - "LFDA(embedding_type='weighted', k=None, num_dims=None, " |
31 |
| - "preprocessor=None)") |
| 35 | + self.assertEqual(remove_spaces(str(metric_learn.LFDA())), |
| 36 | + remove_spaces( |
| 37 | + "LFDA(embedding_type='weighted', k=None, " |
| 38 | + "num_dims=None, " |
| 39 | + "preprocessor=None)")) |
32 | 40 |
|
33 | 41 | def test_itml(self):
|
34 |
| - self.assertEqual(str(metric_learn.ITML()), """ |
| 42 | + self.assertEqual(remove_spaces(str(metric_learn.ITML())), |
| 43 | + remove_spaces(""" |
35 | 44 | ITML(A0=None, convergence_threshold=0.001, gamma=1.0, max_iter=1000,
|
36 | 45 | preprocessor=None, verbose=False)
|
37 |
| -""".strip('\n')) |
38 |
| - self.assertEqual(str(metric_learn.ITML_Supervised()), """ |
| 46 | +""")) |
| 47 | + self.assertEqual(remove_spaces(str(metric_learn.ITML_Supervised())), |
| 48 | + remove_spaces(""" |
39 | 49 | ITML_Supervised(A0=None, bounds='deprecated', convergence_threshold=0.001,
|
40 | 50 | gamma=1.0, max_iter=1000, num_constraints=None,
|
41 | 51 | num_labeled='deprecated', preprocessor=None, verbose=False)
|
42 |
| -""".strip('\n')) |
| 52 | +""")) |
43 | 53 |
|
44 | 54 | def test_lsml(self):
|
45 | 55 | self.assertEqual(
|
46 |
| - str(metric_learn.LSML()), |
| 56 | + remove_spaces(str(metric_learn.LSML())), |
| 57 | + remove_spaces( |
47 | 58 | "LSML(max_iter=1000, preprocessor=None, prior=None, tol=0.001, "
|
48 |
| - "verbose=False)") |
49 |
| - self.assertEqual(str(metric_learn.LSML_Supervised()), """ |
| 59 | + "verbose=False)")) |
| 60 | + self.assertEqual(remove_spaces(str(metric_learn.LSML_Supervised())), |
| 61 | + remove_spaces(""" |
50 | 62 | LSML_Supervised(max_iter=1000, num_constraints=None, num_labeled='deprecated',
|
51 | 63 | preprocessor=None, prior=None, tol=0.001, verbose=False,
|
52 | 64 | weights=None)
|
53 |
| -""".strip('\n')) |
| 65 | +""")) |
54 | 66 |
|
55 | 67 | def test_sdml(self):
|
56 |
| - self.assertEqual(str(metric_learn.SDML()), |
57 |
| - "SDML(balance_param=0.5, preprocessor=None, " |
58 |
| - "sparsity_param=0.01, use_cov=True,\n verbose=False)") |
59 |
| - self.assertEqual(str(metric_learn.SDML_Supervised()), """ |
| 68 | + self.assertEqual(remove_spaces(str(metric_learn.SDML())), |
| 69 | + remove_spaces( |
| 70 | + "SDML(balance_param=0.5, preprocessor=None, " |
| 71 | + "sparsity_param=0.01, use_cov=True," |
| 72 | + "\n verbose=False)")) |
| 73 | + self.assertEqual(remove_spaces(str(metric_learn.SDML_Supervised())), |
| 74 | + remove_spaces(""" |
60 | 75 | SDML_Supervised(balance_param=0.5, num_constraints=None,
|
61 | 76 | num_labeled='deprecated', preprocessor=None, sparsity_param=0.01,
|
62 | 77 | use_cov=True, verbose=False)
|
63 |
| -""".strip('\n')) |
| 78 | +""")) |
64 | 79 |
|
65 | 80 | def test_rca(self):
|
66 |
| - self.assertEqual(str(metric_learn.RCA()), |
67 |
| - "RCA(num_dims=None, pca_comps=None, preprocessor=None)") |
68 |
| - self.assertEqual(str(metric_learn.RCA_Supervised()), |
69 |
| - "RCA_Supervised(chunk_size=2, num_chunks=100, " |
70 |
| - "num_dims=None, pca_comps=None,\n " |
71 |
| - "preprocessor=None)") |
| 81 | + self.assertEqual(remove_spaces(str(metric_learn.RCA())), |
| 82 | + remove_spaces("RCA(num_dims=None, pca_comps=None, " |
| 83 | + "preprocessor=None)")) |
| 84 | + self.assertEqual(remove_spaces(str(metric_learn.RCA_Supervised())), |
| 85 | + remove_spaces( |
| 86 | + "RCA_Supervised(chunk_size=2, num_chunks=100, " |
| 87 | + "num_dims=None, pca_comps=None,\n " |
| 88 | + "preprocessor=None)")) |
72 | 89 |
|
73 | 90 | def test_mlkr(self):
|
74 |
| - self.assertEqual(str(metric_learn.MLKR()), |
75 |
| - "MLKR(A0=None, max_iter=1000, num_dims=None, " |
76 |
| - "preprocessor=None, tol=None,\n verbose=False)") |
| 91 | + self.assertEqual(remove_spaces(str(metric_learn.MLKR())), |
| 92 | + remove_spaces( |
| 93 | + "MLKR(A0=None, max_iter=1000, num_dims=None, " |
| 94 | + "preprocessor=None, tol=None,\n verbose=False)")) |
77 | 95 |
|
78 | 96 | def test_mmc(self):
|
79 |
| - self.assertEqual(str(metric_learn.MMC()), """ |
| 97 | + self.assertEqual(remove_spaces(str(metric_learn.MMC())), |
| 98 | + remove_spaces(""" |
80 | 99 | MMC(A0=None, convergence_threshold=0.001, diagonal=False, diagonal_c=1.0,
|
81 | 100 | max_iter=100, max_proj=10000, preprocessor=None, verbose=False)
|
82 |
| -""".strip('\n')) |
83 |
| - self.assertEqual(str(metric_learn.MMC_Supervised()), """ |
| 101 | +""")) |
| 102 | + self.assertEqual(remove_spaces(str(metric_learn.MMC_Supervised())), |
| 103 | + remove_spaces(""" |
84 | 104 | MMC_Supervised(A0=None, convergence_threshold=1e-06, diagonal=False,
|
85 | 105 | diagonal_c=1.0, max_iter=100, max_proj=10000, num_constraints=None,
|
86 | 106 | num_labeled='deprecated', preprocessor=None, verbose=False)
|
87 |
| -""".strip('\n')) |
| 107 | +""")) |
88 | 108 |
|
89 | 109 |
|
90 | 110 | @pytest.mark.parametrize('estimator, build_dataset', metric_learners,
|
|
0 commit comments