Skip to content

Commit f2a572f

Browse files
committed
iter
1 parent e54c772 commit f2a572f

File tree

1 file changed

+30
-2
lines changed

1 file changed

+30
-2
lines changed

imblearn/utils/_validation.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,19 @@ def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
134134

135135

136136
def get_classes_counts(y):
137+
"""Compute the counts of each class present in `y`.
138+
139+
Parameters
140+
----------
141+
y : ndarray of shape (n_samples,)
142+
The target array.
143+
144+
Returns
145+
-------
146+
classes_counts : dict
147+
A dictionary where the keys are the class labels and the values are the
148+
counts for each class.
149+
"""
137150
unique, counts = np.unique(y, return_counts=True)
138151
if is_dask_collection(unique):
139152
from dask import compute
@@ -542,8 +555,14 @@ def check_sampling_strategy(
542555
correspond to the targeted classes. The values correspond to the
543556
desired number of samples for each class.
544557
545-
y : ndarray of shape (n_samples,)
546-
The target array.
558+
classes_counts : dict or ndarray of shape (n_samples,)
559+
A dictionary where the keys are the class present in `y` and the values
560+
are the counts. The function :func:`~imblearn.utils.get_classes_count`
561+
provides such a dictionary, giving `y` as an input.
562+
563+
.. deprecated:: 0.7
564+
Passing the array `y` is deprecated from 0.7 and will be removed
565+
in 0.9.
547566
548567
sampling_type : {{'over-sampling', 'under-sampling', 'clean-sampling'}}
549568
The type of sampling. Can be either ``'over-sampling'``,
@@ -567,6 +586,15 @@ def check_sampling_strategy(
567586
" instead.".format(SAMPLING_KIND, sampling_type)
568587
)
569588

589+
if hasattr(y, "__array__"):
590+
warnings.warn(
591+
f"Passing that array of target `y` is deprecated in 0.7 and will "
592+
f"raise an error from 0.9. Instead, pass `y` to "
593+
"imblearn.utils.get_classes_counts function to get the "
594+
"dictionary.", FutureWarning
595+
)
596+
classes_counts = get_classes_counts(classes_counts)
597+
570598
if len(classes_counts) <= 1:
571599
raise ValueError(
572600
"The target 'y' needs to have more than 1 class."

0 commit comments

Comments
 (0)