-
Notifications
You must be signed in to change notification settings - Fork 1.3k
ENH duck-typing scikit-learn estimator instead of inheritance #858
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 38 commits
d17b6b5
379ea7e
8790628
e997e23
94b0725
9fbf360
f736879
fcb118e
a4e959c
65ae4fd
5b76d49
8284b70
e97ae36
495ec27
10456f5
93200e1
f104057
b82e4d9
70b6778
c67c775
010f4d5
178d0f0
9868d0f
2e1ee17
8889cfd
5e875a0
9545172
29a414b
12991ba
e24ee06
525002f
189f0e9
2cbe273
cc7fae9
29e4619
a098e84
0aa328e
8cce474
8d4ff31
0ceacfb
ee6b7b0
d089b7b
ac7e00a
d815e2d
964d082
99d5206
48d1fd5
18b6057
76fbd59
8fa97ed
615a2bf
b75b77d
b627cf1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,6 @@ | |
from sklearn.utils import _safe_indexing | ||
|
||
from ..base import BaseOverSampler | ||
from ...exceptions import raise_isinstance_error | ||
from ...utils import check_neighbors_object | ||
from ...utils import Substitution | ||
from ...utils._docstring import _n_jobs_docstring | ||
|
@@ -278,6 +277,8 @@ class SVMSMOTE(BaseSMOTE): | |
|
||
svm_estimator : estimator object, default=SVC() | ||
A parametrized :class:`~sklearn.svm.SVC` classifier can be passed. | ||
A scikit-learn compatible estimator can be passed but it is required | ||
to expose a `support_` fitted attribute. | ||
|
||
out_step : float, default=0.5 | ||
Step size when extrapolating. | ||
|
@@ -385,10 +386,8 @@ def _validate_estimator(self): | |
|
||
if self.svm_estimator is None: | ||
self.svm_estimator_ = SVC(gamma="scale", random_state=self.random_state) | ||
elif isinstance(self.svm_estimator, SVC): | ||
self.svm_estimator_ = clone(self.svm_estimator) | ||
else: | ||
raise_isinstance_error("svm_estimator", [SVC], self.svm_estimator) | ||
self.svm_estimator_ = clone(self.svm_estimator) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change removes the explicit This will enable the integration of SVM estimators that enforce the same API contract as sklearn instead of requiring the explicit class check ( As a motivating example, the integration of a GPU-accelerated SVC from cuML can offer significant performance gains when working with large datasets. Hardware Specs for the Loose Benchmark: Benchmarking gist: |
||
|
||
def _fit_resample(self, X, y): | ||
self._validate_estimator() | ||
|
@@ -403,6 +402,12 @@ def _fit_resample(self, X, y): | |
X_class = _safe_indexing(X, target_class_indices) | ||
|
||
self.svm_estimator_.fit(X, y) | ||
if not hasattr(self.svm_estimator_, "support_"): | ||
raise RuntimeError( | ||
"`svm_estimator` is required to exposed a `support_` fitted " | ||
"attribute. Such estimator belongs to the familly of Support " | ||
"Vector Machine." | ||
) | ||
support_index = self.svm_estimator_.support_[ | ||
y[self.svm_estimator_.support_] == class_sample | ||
] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,7 +49,8 @@ class ClusterCentroids(BaseUnderSampler): | |
{random_state} | ||
|
||
estimator : estimator object, default=None | ||
Pass a :class:`~sklearn.cluster.KMeans` estimator. By default, it will | ||
A scikit-learn compatible clustering method that exposes a `n_clusters` | ||
parameter and a `cluster_centers_` fitted attribute. By default, it will | ||
be a default :class:`~sklearn.cluster.KMeans` estimator. | ||
|
||
voting : {{"hard", "soft", "auto"}}, default='auto' | ||
|
@@ -141,13 +142,13 @@ def _validate_estimator(self): | |
) | ||
if self.estimator is None: | ||
self.estimator_ = KMeans(random_state=self.random_state) | ||
elif isinstance(self.estimator, KMeans): | ||
self.estimator_ = clone(self.estimator) | ||
else: | ||
raise ValueError( | ||
f"`estimator` has to be a KMeans clustering." | ||
f" Got {type(self.estimator)} instead." | ||
) | ||
self.estimator_ = clone(self.estimator) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change removes the explicit This will enable KMeans estimators that enforce the same API contract as sklearn to be integrated instead of requiring the explicit class check ( As a motivating example, the integration of a GPU-accelerated KMeans estimator from cuML can offer significant performance gains when working with large datasets. Hardware Specs for the Loose Benchmark: Benchmarking gist: |
||
if "n_clusters" not in self.estimator_.get_params(): | ||
raise ValueError( | ||
"`estimator` should be a clustering estimator exposing a parameter" | ||
" `n_clusters` and a fitted parameter `cluster_centers_`." | ||
) | ||
|
||
def _generate_sample(self, X, y, centroids, target_class): | ||
if self.voting_ == "hard": | ||
|
@@ -188,6 +189,11 @@ def _fit_resample(self, X, y): | |
n_samples = self.sampling_strategy_[target_class] | ||
self.estimator_.set_params(**{"n_clusters": n_samples}) | ||
self.estimator_.fit(_safe_indexing(X, target_class_indices)) | ||
if not hasattr(self.estimator_, "cluster_centers_"): | ||
raise RuntimeError( | ||
"`estimator` should be a clustering estimator exposing a " | ||
"fitted parameter `cluster_centers_`." | ||
) | ||
X_new, y_new = self._generate_sample( | ||
_safe_indexing(X, target_class_indices), | ||
_safe_indexing(y, target_class_indices), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just wanted to highlight this comment #858 (comment)
I believe using cuML here requires 21.12. 21.12 is the current nightly, which is on track for release this week https://docs.rapids.ai/maintainers
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, but I am kind of struggling to install
cudatoolkit
withconda
. Locally, the package is available but it seems that this is not the case on Azure Pipeline. I am currently making sure that I have the latest version ofconda
.This said, since there is no GPU on Azure, is cuML falling back on some sort of CPU computing or it will just fail?
In which case, I should investigate to find a free CI service for open source where I can get a GPU.