Skip to content

Commit 05a8d41

Browse files
authored
[MRG] Be compatible with newer scikit-learn (#199)
* Update Travis to use previous scikit-learn versions for older Pythons
* Update code to work with both versions
* Install scikit-learn before skggm
* Simpler replacement of spaces and newlines that is compatible with Python 2.7
* Address #199 (review)
* Address #199 (review)
1 parent d4badc8 commit 05a8d41

File tree

6 files changed

+70
-110
lines changed

6 files changed

+70
-110
lines changed

.travis.yml

Lines changed: 6 additions & 1 deletion
Original file line number / Diff line number / Diff line change
@@ -8,7 +8,12 @@ python:
88
before_install:
99
- sudo apt-get install liblapack-dev
1010
- pip install --upgrade pip pytest
11-
- pip install wheel cython numpy scipy scikit-learn codecov pytest-cov
11+
- pip install wheel cython numpy scipy codecov pytest-cov
12+
- if $TRAVIS_PYTHON_VERSION == "3.6"; then
13+
pip install scikit-learn;
14+
else
15+
pip install scikit-learn==0.20.3;
16+
fi
1217
- if [[ ($TRAVIS_PYTHON_VERSION == "3.6") ||
1318
($TRAVIS_PYTHON_VERSION == "2.7")]]; then
1419
pip install git+https://github.com/skggm/skggm.git@a0ed406586c4364ea3297a658f415e13b5cbdaf8;

README.rst

Lines changed: 1 addition & 1 deletion
Original file line number / Diff line number / Diff line change
@@ -20,7 +20,7 @@ Metric Learning algorithms in Python.
2020
**Dependencies**
2121

2222
- Python 2.7+, 3.4+
23-
- numpy, scipy, scikit-learn
23+
- numpy, scipy, scikit-learn>=0.20.3
2424

2525
**Optional dependencies**
2626

doc/getting_started.rst

Lines changed: 1 addition & 1 deletion
Original file line number / Diff line number / Diff line change
@@ -15,7 +15,7 @@ Alternately, download the source repository and run:
1515
**Dependencies**
1616

1717
- Python 2.7+, 3.4+
18-
- numpy, scipy, scikit-learn
18+
- numpy, scipy, scikit-learn>=0.20.3
1919

2020
**Optional dependencies**
2121

metric_learn/_util.py

Lines changed: 4 additions & 10 deletions
Original file line number / Diff line number / Diff line change
@@ -22,8 +22,7 @@ def check_input(input_data, y=None, preprocessor=None,
2222
dtype='numeric', order=None,
2323
copy=False, force_all_finite=True,
2424
multi_output=False, ensure_min_samples=1,
25-
ensure_min_features=1, y_numeric=False,
26-
warn_on_dtype=False, estimator=None):
25+
ensure_min_features=1, y_numeric=False, estimator=None):
2726
"""Checks that the input format is valid, and converts it if specified
2827
(this is the equivalent of scikit-learn's `check_array` or `check_X_y`).
2928
All arguments following tuple_size are scikit-learn's `check_X_y`
@@ -88,10 +87,6 @@ def check_input(input_data, y=None, preprocessor=None,
8887
is originally 1D and ``ensure_2d`` is True. Setting to 0 disables
8988
this check.
9089
91-
warn_on_dtype : boolean (default=False)
92-
Raise DataConversionWarning if the dtype of the input data structure
93-
does not match the requested dtype, causing a memory copy.
94-
9590
estimator : str or estimator instance (default=`None`)
9691
If passed, include the name of the estimator in warning messages.
9792
@@ -111,7 +106,7 @@ def check_input(input_data, y=None, preprocessor=None,
111106
copy=copy, force_all_finite=force_all_finite,
112107
ensure_min_samples=ensure_min_samples,
113108
ensure_min_features=ensure_min_features,
114-
warn_on_dtype=warn_on_dtype, estimator=estimator)
109+
estimator=estimator)
115110

116111
# We need to convert input_data into a numpy.ndarray if possible, before
117112
# any further checks or conversions, and deal with y if needed. Therefore
@@ -321,9 +316,8 @@ def __init__(self, X):
321316
accept_sparse=True, dtype=None,
322317
force_all_finite=False,
323318
ensure_2d=False, allow_nd=True,
324-
ensure_min_samples=0,
325-
ensure_min_features=0,
326-
warn_on_dtype=False, estimator=None)
319+
ensure_min_samples=0, ensure_min_features=0,
320+
estimator=None)
327321
self.X = X
328322

329323
def __call__(self, indices):

test/test_base_metric.py

Lines changed: 58 additions & 38 deletions
Original file line number / Diff line number / Diff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import re
23
import unittest
34
import metric_learn
45
import numpy as np
@@ -7,84 +8,103 @@
78
from test.test_utils import ids_metric_learners, metric_learners
89

910

11+
def remove_spaces(s):
12+
return re.sub('\s+', '', s)
13+
14+
1015
class TestStringRepr(unittest.TestCase):
1116

1217
def test_covariance(self):
13-
self.assertEqual(str(metric_learn.Covariance()),
14-
"Covariance(preprocessor=None)")
18+
self.assertEqual(remove_spaces(str(metric_learn.Covariance())),
19+
remove_spaces("Covariance(preprocessor=None)"))
1520

1621
def test_lmnn(self):
1722
self.assertRegexpMatches(
18-
str(metric_learn.LMNN()),
19-
r"(python_)?LMNN\(convergence_tol=0.001, k=3, learn_rate=1e-07, "
20-
r"max_iter=1000,\n min_iter=50, preprocessor=None, "
21-
r"regularization=0.5, use_pca=True,\n verbose=False\)")
23+
str(metric_learn.LMNN()),
24+
r"(python_)?LMNN\(convergence_tol=0.001, k=3, learn_rate=1e-07, "
25+
r"max_iter=1000,\s+min_iter=50, preprocessor=None, "
26+
r"regularization=0.5, use_pca=True,\s+verbose=False\)")
2227

2328
def test_nca(self):
24-
self.assertEqual(str(metric_learn.NCA()),
25-
"NCA(max_iter=100, num_dims=None, preprocessor=None, "
26-
"tol=None, verbose=False)")
29+
self.assertEqual(remove_spaces(str(metric_learn.NCA())),
30+
remove_spaces(
31+
"NCA(max_iter=100, num_dims=None, preprocessor=None, "
32+
"tol=None, verbose=False)"))
2733

2834
def test_lfda(self):
29-
self.assertEqual(str(metric_learn.LFDA()),
30-
"LFDA(embedding_type='weighted', k=None, num_dims=None, "
31-
"preprocessor=None)")
35+
self.assertEqual(remove_spaces(str(metric_learn.LFDA())),
36+
remove_spaces(
37+
"LFDA(embedding_type='weighted', k=None, "
38+
"num_dims=None, "
39+
"preprocessor=None)"))
3240

3341
def test_itml(self):
34-
self.assertEqual(str(metric_learn.ITML()), """
42+
self.assertEqual(remove_spaces(str(metric_learn.ITML())),
43+
remove_spaces("""
3544
ITML(A0=None, convergence_threshold=0.001, gamma=1.0, max_iter=1000,
3645
preprocessor=None, verbose=False)
37-
""".strip('\n'))
38-
self.assertEqual(str(metric_learn.ITML_Supervised()), """
46+
"""))
47+
self.assertEqual(remove_spaces(str(metric_learn.ITML_Supervised())),
48+
remove_spaces("""
3949
ITML_Supervised(A0=None, bounds='deprecated', convergence_threshold=0.001,
4050
gamma=1.0, max_iter=1000, num_constraints=None,
4151
num_labeled='deprecated', preprocessor=None, verbose=False)
42-
""".strip('\n'))
52+
"""))
4353

4454
def test_lsml(self):
4555
self.assertEqual(
46-
str(metric_learn.LSML()),
56+
remove_spaces(str(metric_learn.LSML())),
57+
remove_spaces(
4758
"LSML(max_iter=1000, preprocessor=None, prior=None, tol=0.001, "
48-
"verbose=False)")
49-
self.assertEqual(str(metric_learn.LSML_Supervised()), """
59+
"verbose=False)"))
60+
self.assertEqual(remove_spaces(str(metric_learn.LSML_Supervised())),
61+
remove_spaces("""
5062
LSML_Supervised(max_iter=1000, num_constraints=None, num_labeled='deprecated',
5163
preprocessor=None, prior=None, tol=0.001, verbose=False,
5264
weights=None)
53-
""".strip('\n'))
65+
"""))
5466

5567
def test_sdml(self):
56-
self.assertEqual(str(metric_learn.SDML()),
57-
"SDML(balance_param=0.5, preprocessor=None, "
58-
"sparsity_param=0.01, use_cov=True,\n verbose=False)")
59-
self.assertEqual(str(metric_learn.SDML_Supervised()), """
68+
self.assertEqual(remove_spaces(str(metric_learn.SDML())),
69+
remove_spaces(
70+
"SDML(balance_param=0.5, preprocessor=None, "
71+
"sparsity_param=0.01, use_cov=True,"
72+
"\n verbose=False)"))
73+
self.assertEqual(remove_spaces(str(metric_learn.SDML_Supervised())),
74+
remove_spaces("""
6075
SDML_Supervised(balance_param=0.5, num_constraints=None,
6176
num_labeled='deprecated', preprocessor=None, sparsity_param=0.01,
6277
use_cov=True, verbose=False)
63-
""".strip('\n'))
78+
"""))
6479

6580
def test_rca(self):
66-
self.assertEqual(str(metric_learn.RCA()),
67-
"RCA(num_dims=None, pca_comps=None, preprocessor=None)")
68-
self.assertEqual(str(metric_learn.RCA_Supervised()),
69-
"RCA_Supervised(chunk_size=2, num_chunks=100, "
70-
"num_dims=None, pca_comps=None,\n "
71-
"preprocessor=None)")
81+
self.assertEqual(remove_spaces(str(metric_learn.RCA())),
82+
remove_spaces("RCA(num_dims=None, pca_comps=None, "
83+
"preprocessor=None)"))
84+
self.assertEqual(remove_spaces(str(metric_learn.RCA_Supervised())),
85+
remove_spaces(
86+
"RCA_Supervised(chunk_size=2, num_chunks=100, "
87+
"num_dims=None, pca_comps=None,\n "
88+
"preprocessor=None)"))
7289

7390
def test_mlkr(self):
74-
self.assertEqual(str(metric_learn.MLKR()),
75-
"MLKR(A0=None, max_iter=1000, num_dims=None, "
76-
"preprocessor=None, tol=None,\n verbose=False)")
91+
self.assertEqual(remove_spaces(str(metric_learn.MLKR())),
92+
remove_spaces(
93+
"MLKR(A0=None, max_iter=1000, num_dims=None, "
94+
"preprocessor=None, tol=None,\n verbose=False)"))
7795

7896
def test_mmc(self):
79-
self.assertEqual(str(metric_learn.MMC()), """
97+
self.assertEqual(remove_spaces(str(metric_learn.MMC())),
98+
remove_spaces("""
8099
MMC(A0=None, convergence_threshold=0.001, diagonal=False, diagonal_c=1.0,
81100
max_iter=100, max_proj=10000, preprocessor=None, verbose=False)
82-
""".strip('\n'))
83-
self.assertEqual(str(metric_learn.MMC_Supervised()), """
101+
"""))
102+
self.assertEqual(remove_spaces(str(metric_learn.MMC_Supervised())),
103+
remove_spaces("""
84104
MMC_Supervised(A0=None, convergence_threshold=1e-06, diagonal=False,
85105
diagonal_c=1.0, max_iter=100, max_proj=10000, num_constraints=None,
86106
num_labeled='deprecated', preprocessor=None, verbose=False)
87-
""".strip('\n'))
107+
"""))
88108

89109

90110
@pytest.mark.parametrize('estimator, build_dataset', metric_learners,

test/test_utils.py

Lines changed: 0 additions & 59 deletions
Original file line number / Diff line number / Diff line change
@@ -300,35 +300,6 @@ def test_check_tuples_invalid_n_samples(estimator, context, load_tuples,
300300
assert str(raised_error.value) == msg
301301

302302

303-
@pytest.mark.parametrize('estimator, context',
304-
[(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")])
305-
@pytest.mark.parametrize('load_tuples, preprocessor',
306-
[(tuples_prep, mock_preprocessor),
307-
(tuples_no_prep, None),
308-
(tuples_no_prep, mock_preprocessor)])
309-
def test_check_tuples_invalid_dtype_convertible(estimator, context,
310-
load_tuples, preprocessor):
311-
"""Checks that a warning is raised if a convertible input is converted to
312-
float"""
313-
tuples = load_tuples().astype(object) # here the object conversion is
314-
# useless for the tuples_prep case, but this allows to test the
315-
# tuples_prep case
316-
317-
if preprocessor is not None: # if the preprocessor is not None we
318-
# overwrite it to have a preprocessor that returns objects
319-
def preprocessor(indices): #
320-
# preprocessor that returns objects
321-
return np.ones((indices.shape[0], 3)).astype(object)
322-
323-
msg = ("Data with input dtype object was converted to float64{}."
324-
.format(context))
325-
with pytest.warns(DataConversionWarning) as raised_warning:
326-
check_input(tuples, type_of_inputs='tuples',
327-
preprocessor=preprocessor, dtype=np.float64,
328-
warn_on_dtype=True, estimator=estimator)
329-
assert str(raised_warning[0].message) == msg
330-
331-
332303
def test_check_tuples_invalid_dtype_not_convertible_with_preprocessor():
333304
"""Checks that a value error is thrown if attempting to convert an
334305
input not convertible to float, when using a preprocessor
@@ -530,36 +501,6 @@ def test_check_classic_invalid_n_samples(estimator, context, load_points,
530501
assert str(raised_error.value) == msg
531502

532503

533-
@pytest.mark.parametrize('estimator, context',
534-
[(NCA(), " by NCA"), ('NCA', " by NCA"), (None, "")])
535-
@pytest.mark.parametrize('load_points, preprocessor',
536-
[(points_prep, mock_preprocessor),
537-
(points_no_prep, None),
538-
(points_no_prep, mock_preprocessor)])
539-
def test_check_classic_invalid_dtype_convertible(estimator, context,
540-
load_points,
541-
preprocessor):
542-
"""Checks that a warning is raised if a convertible input is converted to
543-
float"""
544-
points = load_points().astype(object) # here the object conversion is
545-
# useless for the points_prep case, but this allows to test the
546-
# points_prep case
547-
548-
if preprocessor is not None: # if the preprocessor is not None we
549-
# overwrite it to have a preprocessor that returns objects
550-
def preprocessor(indices):
551-
# preprocessor that returns objects
552-
return np.ones((indices.shape[0], 3)).astype(object)
553-
554-
msg = ("Data with input dtype object was converted to float64{}."
555-
.format(context))
556-
with pytest.warns(DataConversionWarning) as raised_warning:
557-
check_input(points, type_of_inputs='classic',
558-
preprocessor=preprocessor, dtype=np.float64,
559-
warn_on_dtype=True, estimator=estimator)
560-
assert str(raised_warning[0].message) == msg
561-
562-
563504
@pytest.mark.parametrize('preprocessor, points',
564505
[(mock_preprocessor, np.array([['a', 'b'],
565506
['e', 'b']])),

0 commit comments

Comments (0)