-
Notifications
You must be signed in to change notification settings - Fork 1.3k
ENH duck-typing scikit-learn estimator instead of inheritance #858
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
d17b6b5
379ea7e
8790628
e997e23
94b0725
9fbf360
f736879
fcb118e
a4e959c
65ae4fd
5b76d49
8284b70
e97ae36
495ec27
10456f5
93200e1
f104057
b82e4d9
70b6778
c67c775
010f4d5
178d0f0
9868d0f
2e1ee17
8889cfd
5e875a0
9545172
29a414b
12991ba
e24ee06
525002f
189f0e9
2cbe273
cc7fae9
29e4619
a098e84
0aa328e
8cce474
8d4ff31
0ceacfb
ee6b7b0
d089b7b
ac7e00a
d815e2d
964d082
99d5206
48d1fd5
18b6057
76fbd59
8fa97ed
615a2bf
b75b77d
b627cf1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -123,13 +123,8 @@ def _validate_estimator(self): | |
) | ||
if self.estimator is None: | ||
self.estimator_ = KMeans(random_state=self.random_state) | ||
elif isinstance(self.estimator, KMeans): | ||
self.estimator_ = clone(self.estimator) | ||
else: | ||
raise ValueError( | ||
f"`estimator` has to be a KMeans clustering." | ||
f" Got {type(self.estimator)} instead." | ||
) | ||
self.estimator_ = clone(self.estimator) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change removes the explicit This will enable KMeans estimators that enforce the same API contract as sklearn to be integrated instead of requiring the explicit class check ( As a motivating example, the integration of a GPU-accelerated KMeans estimator from cuML can offer significant performance gains when working with large datasets. Hardware Specs for the Loose Benchmark: Benchmarking gist: |
||
|
||
def _generate_sample(self, X, y, centroids, target_class): | ||
if self.voting_ == "hard": | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -67,6 +67,14 @@ def _transfrom_one(self, array, props): | |
return ret | ||
|
||
|
||
def _is_neighbors_object(kneighbors_estimator): | ||
neighbors_attributes = [ | ||
"kneighbors", | ||
"kneighbors_graph" | ||
] | ||
return all(hasattr(kneighbors_estimator, attr) for attr in neighbors_attributes) | ||
|
||
|
||
def check_neighbors_object(nn_name, nn_object, additional_neighbor=0): | ||
"""Check the objects is consistent to be a NN. | ||
|
||
|
@@ -93,7 +101,7 @@ def check_neighbors_object(nn_name, nn_object, additional_neighbor=0): | |
""" | ||
if isinstance(nn_object, Integral): | ||
return NearestNeighbors(n_neighbors=nn_object + additional_neighbor) | ||
elif isinstance(nn_object, KNeighborsMixin): | ||
elif _is_neighbors_object(nn_object): | ||
return clone(nn_object) | ||
else: | ||
raise_isinstance_error(nn_name, [int, KNeighborsMixin], nn_object) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that we should as well change the error message since we don't strictly require to be a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have been debating between two implementations here (my two latest commits [1 2]). [1] uses sklearn.base.clone to verify that the nn_object is an sklearn-like estimator that can be cloned. This implementation is more consistent with how the library checks the integrity of other estimators - such as the KMeans Estimator check in [2] raises a TypeError if the nn_object is neither an integer, nor exposes both Do you prefer one over the other? |
||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change removes the explicit
isinstance
check for validating the SVC estimator in SVMSMOTE's_validate_estimator
method; the estimator is instead validated by way ofsklearn.base.clone()
, similar to that ofKMeansSMOTE
.This will enable the integration of SVM estimators that enforce the same API contract as sklearn instead of requiring the explicit class check (
isinstance(svm_estimator, sklearn.svm.SVC)
)As a motivating example, the integration of a GPU-accelerated SVC from cuML can offer significant performance gains when working with large datasets.
Hardware Specs for the Loose Benchmark:
Intel Xeon E5-2698, 2.2 GHz, 16-cores & NVIDIA V100 32 GB GPU
Benchmarking gist:
https://gist.github.com/NV-jpt/039a8d9c7d37365379faa1d7c7aafc5e