9
9
import pytest
10
10
import numpy as np
11
11
12
- from sklearn .base import BaseEstimator
13
12
from sklearn .neighbors ._base import KNeighborsMixin
14
13
from sklearn .neighbors import NearestNeighbors
15
14
from sklearn .utils ._testing import assert_array_equal
16
15
17
- from imblearn .utils .testing import warns
18
16
from imblearn .utils import check_neighbors_object
19
17
from imblearn .utils import check_sampling_strategy
20
18
from imblearn .utils import check_target_type
19
+ from imblearn .utils .testing import warns , _CustomNearestNeighbors
21
20
from imblearn .utils ._validation import ArraysTransformer
22
21
from imblearn .utils ._validation import _deprecate_positional_args
23
22
24
23
multiclass_target = np .array ([1 ] * 50 + [2 ] * 100 + [3 ] * 25 )
25
24
binary_target = np .array ([1 ] * 25 + [0 ] * 100 )
26
25
27
26
28
- class KNNLikeEstimator (BaseEstimator ):
29
- """A class exposing the same KNeighborsMixin API than KNeighborsClassifier."""
30
-
31
- def kneighbors (self , X ):
32
- return np .ones ((len (X ), 1 ))
33
-
34
- def kneighbors_graph (self , X ):
35
- return np .ones ((len (X ), 1 ))
36
-
37
-
38
27
def test_check_neighbors_object ():
39
28
name = "n_neighbors"
40
29
n_neighbors = 1
@@ -47,9 +36,9 @@ def test_check_neighbors_object():
47
36
estimator = NearestNeighbors (n_neighbors = n_neighbors )
48
37
estimator_cloned = check_neighbors_object (name , estimator )
49
38
assert estimator .n_neighbors == estimator_cloned .n_neighbors
50
- estimator = KNNLikeEstimator ()
39
+ estimator = _CustomNearestNeighbors ()
51
40
estimator_cloned = check_neighbors_object (name , estimator )
52
- assert isinstance (estimator_cloned , KNNLikeEstimator )
41
+ assert isinstance (estimator_cloned , _CustomNearestNeighbors )
53
42
n_neighbors = "rnd"
54
43
err_msg = (
55
44
"n_neighbors must be an interger or an object compatible with the "
0 commit comments