Skip to content

Commit 5064c7e

Browse files
committed
Merge remote-tracking branch 'origin/master' into is/428
2 parents 0de85a7 + cc78dde commit 5064c7e

File tree

11 files changed

+27
-15
lines changed

11 files changed

+27
-15
lines changed

doc/whats_new/v0.0.4.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ Bug fixes
6060
and thus to obtain a deterministic results when using the same random state.
6161
:issue:`447` by :user:`Guillaume Lemaitre <glemaitre>`.
6262

63+
- Force to clone scikit-learn estimator passed as attributes to samplers.
64+
:issue:`446` by :user:`Guillaume Lemaitre <glemaitre>`.
65+
6366
Maintenance
6467
...........
6568

imblearn/combine/smote_enn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
import warnings
1111

12+
from sklearn.base import clone
1213
from sklearn.utils import check_X_y
1314

1415
from ..base import SamplerMixin
@@ -103,7 +104,7 @@ def _validate_estimator(self):
103104
"Private function to validate SMOTE and ENN objects"
104105
if self.smote is not None:
105106
if isinstance(self.smote, SMOTE):
106-
self.smote_ = self.smote
107+
self.smote_ = clone(self.smote)
107108
else:
108109
raise ValueError('smote needs to be a SMOTE object.'
109110
'Got {} instead.'.format(type(self.smote)))
@@ -116,7 +117,7 @@ def _validate_estimator(self):
116117

117118
if self.enn is not None:
118119
if isinstance(self.enn, EditedNearestNeighbours):
119-
self.enn_ = self.enn
120+
self.enn_ = clone(self.enn)
120121
else:
121122
raise ValueError('enn needs to be an EditedNearestNeighbours.'
122123
' Got {} instead.'.format(type(self.enn)))

imblearn/combine/smote_tomek.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import logging
1111
import warnings
1212

13+
from sklearn.base import clone
1314
from sklearn.utils import check_X_y
1415

1516
from ..base import SamplerMixin
@@ -111,7 +112,7 @@ def _validate_estimator(self):
111112

112113
if self.smote is not None:
113114
if isinstance(self.smote, SMOTE):
114-
self.smote_ = self.smote
115+
self.smote_ = clone(self.smote)
115116
else:
116117
raise ValueError('smote needs to be a SMOTE object.'
117118
'Got {} instead.'.format(type(self.smote)))
@@ -124,7 +125,7 @@ def _validate_estimator(self):
124125

125126
if self.tomek is not None:
126127
if isinstance(self.tomek, TomekLinks):
127-
self.tomek_ = self.tomek
128+
self.tomek_ = clone(self.tomek)
128129
else:
129130
raise ValueError('tomek needs to be a TomekLinks object.'
130131
'Got {} instead.'.format(type(self.tomek)))

imblearn/ensemble/balance_cascade.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import numpy as np
1010

11-
from sklearn.base import ClassifierMixin
11+
from sklearn.base import ClassifierMixin, clone
1212
from sklearn.neighbors import KNeighborsClassifier
1313
from sklearn.utils import check_random_state, safe_indexing
1414
from sklearn.model_selection import cross_val_predict
@@ -142,7 +142,7 @@ def _validate_estimator(self):
142142
if (self.estimator is not None and
143143
isinstance(self.estimator, ClassifierMixin) and
144144
hasattr(self.estimator, 'predict')):
145-
self.estimator_ = self.estimator
145+
self.estimator_ = clone(self.estimator)
146146
elif self.estimator is None:
147147
self.estimator_ = KNeighborsClassifier()
148148
else:

imblearn/over_sampling/smote.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from scipy import sparse
1616

17+
from sklearn.base import clone
1718
from sklearn.svm import SVC
1819
from sklearn.utils import check_random_state, safe_indexing
1920

@@ -448,7 +449,7 @@ def _validate_estimator(self):
448449
if self.svm_estimator is None:
449450
self.svm_estimator_ = SVC(random_state=self.random_state)
450451
elif isinstance(self.svm_estimator, SVC):
451-
self.svm_estimator_ = self.svm_estimator
452+
self.svm_estimator_ = clone(self.svm_estimator)
452453
else:
453454
raise_isinstance_error('svm_estimator', [SVC],
454455
self.svm_estimator)
@@ -698,7 +699,7 @@ def _validate_estimator(self):
698699
self.svm_estimator == 'deprecated'):
699700
self.svm_estimator_ = SVC(random_state=self.random_state)
700701
elif isinstance(self.svm_estimator, SVC):
701-
self.svm_estimator_ = self.svm_estimator
702+
self.svm_estimator_ = clone(self.svm_estimator)
702703
else:
703704
raise_isinstance_error('svm_estimator', [SVC],
704705
self.svm_estimator)

imblearn/under_sampling/prototype_generation/cluster_centroids.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212
from scipy import sparse
1313

14+
from sklearn.base import clone
1415
from sklearn.cluster import KMeans
1516
from sklearn.neighbors import NearestNeighbors
1617
from sklearn.utils import safe_indexing
@@ -113,7 +114,7 @@ def _validate_estimator(self):
113114
self.estimator_ = KMeans(
114115
random_state=self.random_state, n_jobs=self.n_jobs)
115116
elif isinstance(self.estimator, KMeans):
116-
self.estimator_ = self.estimator
117+
self.estimator_ = clone(self.estimator)
117118
else:
118119
raise ValueError('`estimator` has to be a KMeans clustering.'
119120
' Got {} instead.'.format(type(self.estimator)))

imblearn/under_sampling/prototype_selection/condensed_nearest_neighbour.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from scipy.sparse import issparse
1515

16+
from sklearn.base import clone
1617
from sklearn.neighbors import KNeighborsClassifier
1718
from sklearn.utils import check_random_state, safe_indexing
1819

@@ -121,7 +122,7 @@ def _validate_estimator(self):
121122
self.estimator_ = KNeighborsClassifier(
122123
n_neighbors=self.n_neighbors, n_jobs=self.n_jobs)
123124
elif isinstance(self.n_neighbors, KNeighborsClassifier):
124-
self.estimator_ = self.n_neighbors
125+
self.estimator_ = clone(self.n_neighbors)
125126
else:
126127
raise ValueError('`n_neighbors` has to be a int or an object'
127128
' inhereited from KNeighborsClassifier.'

imblearn/under_sampling/prototype_selection/instance_hardness_threshold.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import numpy as np
1414

15-
from sklearn.base import ClassifierMixin
15+
from sklearn.base import ClassifierMixin, clone
1616
from sklearn.ensemble import RandomForestClassifier
1717
from sklearn.model_selection import StratifiedKFold
1818
from sklearn.utils import safe_indexing
@@ -117,7 +117,7 @@ def _validate_estimator(self):
117117
if (self.estimator is not None and
118118
isinstance(self.estimator, ClassifierMixin) and
119119
hasattr(self.estimator, 'predict_proba')):
120-
self.estimator_ = self.estimator
120+
self.estimator_ = clone(self.estimator)
121121
elif self.estimator is None:
122122
self.estimator_ = RandomForestClassifier(
123123
random_state=self.random_state, n_jobs=self.n_jobs)

imblearn/under_sampling/prototype_selection/one_sided_selection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from collections import Counter
1010

1111
import numpy as np
12+
13+
from sklearn.base import clone
1214
from sklearn.neighbors import KNeighborsClassifier
1315
from sklearn.utils import check_random_state, safe_indexing
1416

@@ -114,7 +116,7 @@ def _validate_estimator(self):
114116
self.estimator_ = KNeighborsClassifier(
115117
n_neighbors=self.n_neighbors, n_jobs=self.n_jobs)
116118
elif isinstance(self.n_neighbors, KNeighborsClassifier):
117-
self.estimator_ = self.n_neighbors
119+
self.estimator_ = clone(self.n_neighbors)
118120
else:
119121
raise ValueError('`n_neighbors` has to be a int or an object'
120122
' inhereited from KNeighborsClassifier.'

imblearn/utils/tests/test_validation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def test_check_neighbors_object():
3636
assert issubclass(type(estimator), KNeighborsMixin)
3737
assert estimator.n_neighbors == 2
3838
estimator = NearestNeighbors(n_neighbors)
39-
assert estimator is check_neighbors_object(name, estimator)
39+
estimator_cloned = check_neighbors_object(name, estimator)
40+
assert estimator.n_neighbors == estimator_cloned.n_neighbors
4041
n_neighbors = 'rnd'
4142
with pytest.raises(ValueError, match="has to be one of"):
4243
check_neighbors_object(name, n_neighbors)

imblearn/utils/validation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import numpy as np
1313

14+
from sklearn.base import clone
1415
from sklearn.neighbors.base import KNeighborsMixin
1516
from sklearn.neighbors import NearestNeighbors
1617
from sklearn.externals import six, joblib
@@ -51,7 +52,7 @@ def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
5152
if isinstance(nn_object, Integral):
5253
return NearestNeighbors(n_neighbors=nn_object + additional_neighbor)
5354
elif isinstance(nn_object, KNeighborsMixin):
54-
return nn_object
55+
return clone(nn_object)
5556
else:
5657
raise_isinstance_error(nn_name, [int, KNeighborsMixin], nn_object)
5758

0 commit comments

Comments
 (0)