Skip to content

Commit 8790628

Browse files
sft-managedsft-managed
sft-managed
authored and
sft-managed
committed
Added _is_neighbors_object() private validation function
1 parent 379ea7e commit 8790628

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

imblearn/utils/_validation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ def _transfrom_one(self, array, props):
6666
ret = array
6767
return ret
6868

69+
def _is_neighbors_object(kneighbors_estimator):
70+
neighbors_attributes = [
71+
"kneighbors",
72+
"kneighbors_graph"
73+
]
74+
return all(hasattr(kneighbors_estimator, attr) for attr in neighbors_attributes)
6975

7076
def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
7177
"""Check the objects is consistent to be a NN.
@@ -93,7 +99,7 @@ def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
9399
"""
94100
if isinstance(nn_object, Integral):
95101
return NearestNeighbors(n_neighbors=nn_object + additional_neighbor)
96-
elif hasattr(nn_object, 'kneighbors') and hasattr(nn_object, 'kneighbors_graph'):
102+
elif _is_neighbors_object(nn_object):
97103
return clone(nn_object)
98104
else:
99105
raise_isinstance_error(nn_name, [int, KNeighborsMixin], nn_object)

0 commit comments

Comments
 (0)