Commit cddf39b
[MRG] EHN handling sparse matrices whenever possible (#316)
* EHN POC sparse handling for RandomUnderSampler
* EHN support sparse ENN
* iter
* EHN sparse indexing IHT
* EHN sparse support nearmiss
* EHN support sparse matrices for NCR
* EHN support sparse Tomek and OSS
* EHN support sparsity for CNN
* EHN support sparse for SMOTE
* EHN support sparse adasyn
* EHN support sparsity for sombine methods
* EHN support sparsity BC
* DOC update docstring
* DOC fix example topic classification
* FIX fix test and class clustercentroids
* TST add common test
* TST add ensemble
* TST use allclose
* TST install conda with ubuntu container
* TST increase tolerance
* TST increase tolerance
* TST test all versions NearMiss and SMOTE
* TST set the algorithm of KMeans
* DOC add entry in user guide
* DOC add entry sparse for CC
* DOC whatsnew entry
* DOC fix api
* TST adapt pytest
* DOC update user guide
* address comments
* TST remove the last assert_regex
1 parent 488a0e8 commit cddf39b

33 files changed: +682 additions, -550 deletions

appveyor.yml

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ install:
 - "python -c \"import struct; print(struct.calcsize('P') * 8)\""

 # Installed prebuilt dependencies from conda
-- "conda install pip numpy scipy scikit-learn=0.19.0 nose wheel matplotlib -y -q"
+- "conda install pip numpy scipy scikit-learn=0.19.0 pandas nose wheel matplotlib -y -q"

 # Install other nilearn dependencies
 - "pip install coverage nose-timer pytest pytest-cov"

build_tools/travis/install.sh

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@ if [[ "$DISTRIB" == "conda" ]]; then
     # provided versions
     conda create -n testenv --yes python=$PYTHON_VERSION pip
     source activate testenv
-    conda install --yes numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION
+    conda install --yes numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION pandas

     if [[ "$SKLEARN_VERSION" == "master" ]]; then
         conda install --yes cython
@@ -59,7 +59,7 @@ elif [[ "$DISTRIB" == "ubuntu" ]]; then
     # Create a new virtualenv using system site packages for python, numpy
     virtualenv --system-site-packages testvenv
     source testvenv/bin/activate
-    pip install scikit-learn nose nose-timer pytest pytest-cov codecov
+    pip install scikit-learn pandas nose nose-timer pytest pytest-cov codecov

 fi

doc/combine.rst

Lines changed: 6 additions & 6 deletions
@@ -29,18 +29,18 @@ than their former samplers::
 ...                            n_clusters_per_class=1,
 ...                            weights=[0.01, 0.05, 0.94],
 ...                            class_sep=0.8, random_state=0)
->>> print(Counter(y))
-Counter({2: 4674, 1: 262, 0: 64})
+>>> print(sorted(Counter(y).items()))
+[(0, 64), (1, 262), (2, 4674)]
 >>> from imblearn.combine import SMOTEENN
 >>> smote_enn = SMOTEENN(random_state=0)
 >>> X_resampled, y_resampled = smote_enn.fit_sample(X, y)
->>> print(Counter(y_resampled))
-Counter({1: 4381, 0: 4060, 2: 3502})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 4060), (1, 4381), (2, 3502)]
 >>> from imblearn.combine import SMOTETomek
 >>> smote_tomek = SMOTETomek(random_state=0)
 >>> X_resampled, y_resampled = smote_tomek.fit_sample(X, y)
->>> print(Counter(y_resampled))
-Counter({1: 4566, 0: 4499, 2: 4413})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 4499), (1, 4566), (2, 4413)]

 We can also see in the example below that :class:`SMOTEENN` tends to clean more
 noisy samples than :class:`SMOTETomek`.
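
A minimal sketch (not part of the diff) showing that, with this patch, the combine methods also accept sparse input; the toy problem from the doctest above is reused::

    # Sketch: same data as the doctest, but fed as CSR; the output stays sparse.
    from collections import Counter

    from scipy import sparse
    from sklearn.datasets import make_classification
    from imblearn.combine import SMOTEENN

    X, y = make_classification(n_samples=5000, n_classes=3, n_informative=3,
                               n_clusters_per_class=1,
                               weights=[0.01, 0.05, 0.94],
                               class_sep=0.8, random_state=0)
    smote_enn = SMOTEENN(random_state=0)
    X_resampled, y_resampled = smote_enn.fit_sample(sparse.csr_matrix(X), y)
    print(sparse.issparse(X_resampled), sorted(Counter(y_resampled).items()))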

doc/datasets/index.rst

Lines changed: 8 additions & 9 deletions
@@ -85,8 +85,8 @@ A specific data set can be selected as::
 >>> ecoli = fetch_datasets()['ecoli']
 >>> ecoli.data.shape
 (336, 7)
->>> print(Counter((ecoli.target)))
-Counter({-1: 301, 1: 35})
+>>> print(sorted(Counter(ecoli.target).items()))
+[(-1, 301), (1, 35)]

 .. _make_imbalanced:

@@ -104,16 +104,16 @@ samples in the class::
 >>> iris = load_iris()
 >>> ratio = {0: 20, 1: 30, 2: 40}
 >>> X_imb, y_imb = make_imbalance(iris.data, iris.target, ratio=ratio)
->>> Counter(y_imb)
-Counter({2: 40, 1: 30, 0: 20})
+>>> sorted(Counter(y_imb).items())
+[(0, 20), (1, 30), (2, 40)]

 Note that all samples of a class are passed-through if the class is not mentioned
 in the dictionary::

 >>> ratio = {0: 10}
 >>> X_imb, y_imb = make_imbalance(iris.data, iris.target, ratio=ratio)
->>> Counter(y_imb)
-Counter({1: 50, 2: 50, 0: 10})
+>>> sorted(Counter(y_imb).items())
+[(0, 10), (1, 50), (2, 50)]

 Instead of a dictionary, a function can be defined and directly pass to
 ``ratio``::
@@ -126,9 +126,8 @@ Instead of a dictionary, a function can be defined and directly pass to
 ...     return target_stats
 >>> X_imb, y_imb = make_imbalance(iris.data, iris.target,
 ...                               ratio=ratio_multiplier)
->>> Counter(y_imb)
-Counter({2: 47, 1: 35, 0: 25})
-
+>>> sorted(Counter(y_imb).items())
+[(0, 25), (1, 35), (2, 47)]

 See :ref:`sphx_glr_auto_examples_datasets_plot_make_imbalance.py` and
 :ref:`sphx_glr_auto_examples_plot_ratio_usage.py`.
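
For completeness, the callable-``ratio`` pattern whose body is truncated in the hunk above can be written out as follows; the multiplier values are an assumption chosen to reproduce the printed counts::

    from collections import Counter

    from sklearn.datasets import load_iris
    from imblearn.datasets import make_imbalance

    def ratio_multiplier(y):
        # Shrink each class (50 samples each in iris) to a fraction of its
        # original size; the result is the dict of target counts per class.
        multiplier = {0: 0.5, 1: 0.7, 2: 0.95}
        target_stats = Counter(y)
        for key, value in target_stats.items():
            target_stats[key] = int(value * multiplier[key])
        return target_stats

    iris = load_iris()
    X_imb, y_imb = make_imbalance(iris.data, iris.target,
                                  ratio=ratio_multiplier)
    print(sorted(Counter(y_imb).items()))  # [(0, 25), (1, 35), (2, 47)]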

doc/ensemble.rst

Lines changed: 6 additions & 6 deletions
@@ -19,15 +19,15 @@ under-sampling the original set::
 ...                            n_clusters_per_class=1,
 ...                            weights=[0.01, 0.05, 0.94],
 ...                            class_sep=0.8, random_state=0)
->>> print(Counter(y))
-Counter({2: 4674, 1: 262, 0: 64})
+>>> print(sorted(Counter(y).items()))
+[(0, 64), (1, 262), (2, 4674)]
 >>> from imblearn.ensemble import EasyEnsemble
 >>> ee = EasyEnsemble(random_state=0, n_subsets=10)
 >>> X_resampled, y_resampled = ee.fit_sample(X, y)
 >>> print(X_resampled.shape)
 (10, 192, 2)
->>> print(Counter(y_resampled[0])) # doctest: +SKIP
-Counter({0: 64, 1: 64, 2: 64})
+>>> print(sorted(Counter(y_resampled[0]).items()))
+[(0, 64), (1, 64), (2, 64)]

 :class:`EasyEnsemble` has two important parameters: (i) ``n_subsets`` will be
 used to return number of subset and (ii) ``replacement`` to randomly sample
@@ -48,8 +48,8 @@ parameter ``n_max_subset`` and an additional bootstraping can be activated with
 >>> X_resampled, y_resampled = bc.fit_sample(X, y)
 >>> print(X_resampled.shape)
 (4, 192, 2)
->>> print(Counter(y_resampled[0])) # doctest: +SKIP
-Counter({2: 64, 1: 64, 0: 64})
+>>> print(sorted(Counter(y_resampled[0]).items()))
+[(0, 64), (1, 64), (2, 64)]

 See
 :ref:`sphx_glr_auto_examples_ensemble_plot_easy_ensemble.py` and
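
A short sketch of the two :class:`EasyEnsemble` parameters called out above; ``replacement=True`` is an illustrative choice, not taken from the diff::

    from sklearn.datasets import make_classification
    from imblearn.ensemble import EasyEnsemble

    X, y = make_classification(n_samples=5000, n_classes=3, n_informative=3,
                               n_clusters_per_class=1,
                               weights=[0.01, 0.05, 0.94],
                               class_sep=0.8, random_state=0)
    # n_subsets sets how many balanced subsets are returned; replacement
    # draws samples with replacement within each subset (a bootstrap).
    ee = EasyEnsemble(random_state=0, n_subsets=10, replacement=True)
    X_resampled, y_resampled = ee.fit_sample(X, y)
    print(X_resampled.shape)  # (n_subsets, n_samples_per_subset, n_features)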

doc/introduction.rst

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+.. _introduction:
+
+============
+Introduction
+============
+
+.. _api_imblearn:
+
+API's of imbalanced-learn samplers
+----------------------------------
+
+The available samplers follow the scikit-learn API, using a base estimator and adding sampling functionality through the ``sample`` method:
+
+:Estimator:
+
+    The base object implements a ``fit`` method to learn from data::
+
+        estimator = obj.fit(data, targets)
+
+:Sampler:
+
+    To resample a data set, each sampler implements::
+
+        data_resampled, targets_resampled = obj.sample(data, targets)
+
+    Fitting and sampling can also be done in one step::
+
+        data_resampled, targets_resampled = obj.fit_sample(data, targets)
+
+Imbalanced-learn samplers accept the same inputs as scikit-learn estimators:
+
+* ``data``: array-like (2-D list, pandas.Dataframe, numpy.array) or sparse
+  matrices;
+* ``targets``: array-like (1-D list, pandas.Series, numpy.array).
+
+.. topic:: Sparse input
+
+   For sparse input the data is **converted to the Compressed Sparse Row
+   representation** (see ``scipy.sparse.csr_matrix``) before being fed to the
+   sampler. To avoid unnecessary memory copies, it is recommended to choose the
+   CSR representation upstream.
+
+.. _problem_statement:
+
+Problem statement regarding imbalanced data sets
+------------------------------------------------
+
+The learning phase and the subsequent prediction of machine learning algorithms
+can be affected by the problem of imbalanced data sets. The balancing issue
+corresponds to the difference of the number of samples in the different
+classes. We illustrate the effect of training a linear SVM classifier with
+different levels of class balancing.
+
+.. image:: ./auto_examples/over-sampling/images/sphx_glr_plot_comparison_over_sampling_001.png
+   :target: ./auto_examples/over-sampling/plot_comparison_over_sampling.html
+   :scale: 60
+   :align: center
+
+As expected, the decision function of the linear SVM is highly impacted. With a
+greater imbalance ratio, the decision function favors the class with the larger
+number of samples, usually referred to as the majority class.
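
A minimal sketch of the sparse-input behaviour described in the topic above, assuming a sampler from this patch; choosing CSR upstream avoids the internal conversion copy::

    # Sketch only: exercises the CSR handling documented in the new
    # introduction; assumes imbalanced-learn with this patch applied.
    from scipy import sparse
    from sklearn.datasets import make_classification
    from imblearn.under_sampling import RandomUnderSampler

    X, y = make_classification(n_samples=1000, weights=[0.1, 0.9],
                               random_state=0)
    X_sparse = sparse.csr_matrix(X)  # already CSR: no conversion copy needed

    rus = RandomUnderSampler(random_state=0)
    X_resampled, y_resampled = rus.fit_sample(X_sparse, y)
    print(sparse.issparse(X_resampled))  # True: output keeps the sparse format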

doc/over_sampling.rst

Lines changed: 9 additions & 9 deletions
@@ -29,15 +29,15 @@ randomly sampling with replacement the current available samples. The
 >>> ros = RandomOverSampler(random_state=0)
 >>> X_resampled, y_resampled = ros.fit_sample(X, y)
 >>> from collections import Counter
->>> print(Counter(y_resampled)) # doctest: +SKIP
-Counter({2: 4674, 1: 4674, 0: 4674})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 4674), (1, 4674), (2, 4674)]

 The augmented data set should be used instead of the original data set to train
 a classifier::

 >>> from sklearn.svm import LinearSVC
 >>> clf = LinearSVC()
->>> clf.fit(X_resampled, y_resampled) # doctest: +ELLIPSIS
+>>> clf.fit(X_resampled, y_resampled) # doctest : +ELLIPSIS
 LinearSVC(...)

 In the figure below, we compare the decision functions of a classifier trained
@@ -67,12 +67,12 @@ can be used in the same manner::

 >>> from imblearn.over_sampling import SMOTE, ADASYN
 >>> X_resampled, y_resampled = SMOTE().fit_sample(X, y)
->>> print(Counter(y_resampled)) # doctest: +SKIP
-Counter({2: 4674, 1: 4674, 0: 4674})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 4674), (1, 4674), (2, 4674)]
 >>> clf_smote = LinearSVC().fit(X_resampled, y_resampled)
 >>> X_resampled, y_resampled = ADASYN().fit_sample(X, y)
->>> print(Counter(y_resampled))
-Counter({2: 4674, 0: 4673, 1: 4662})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 4673), (1, 4662), (2, 4674)]
 >>> clf_adasyn = LinearSVC().fit(X_resampled, y_resampled)

 The figure below illustrates the major difference of the different over-sampling
@@ -132,8 +132,8 @@ available: (i) ``'borderline1'``, (ii) ``'borderline2'``, and (iii) ``'svm'``::

 >>> from imblearn.over_sampling import SMOTE, ADASYN
 >>> X_resampled, y_resampled = SMOTE(kind='borderline1').fit_sample(X, y)
->>> print(Counter(y_resampled)) # doctest: +SKIP
-Counter({2: 4674, 1: 4674, 0: 4674})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 4674), (1, 4674), (2, 4674)]

 See :ref:`sphx_glr_auto_examples_over-sampling_plot_comparison_over_sampling.py`
 to see a comparison between the different over-sampling methods.
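
A hedged sketch, not part of the commit: after this change the SMOTE variants above run on sparse input as well, generating synthetic samples without densifying the data::

    from scipy import sparse
    from sklearn.datasets import make_classification
    from imblearn.over_sampling import SMOTE

    X, y = make_classification(n_samples=5000, n_classes=3, n_informative=3,
                               n_clusters_per_class=1,
                               weights=[0.01, 0.05, 0.94],
                               class_sep=0.8, random_state=0)
    smote = SMOTE(kind='borderline1')
    X_resampled, y_resampled = smote.fit_sample(sparse.csr_matrix(X), y)
    print(sparse.issparse(X_resampled))  # synthetic samples kept sparse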

doc/problem_statement.rst

Lines changed: 0 additions & 20 deletions
This file was deleted.

doc/under_sampling.rst

Lines changed: 30 additions & 24 deletions
@@ -28,13 +28,13 @@ K-means method instead of the original samples::
 ...                            n_clusters_per_class=1,
 ...                            weights=[0.01, 0.05, 0.94],
 ...                            class_sep=0.8, random_state=0)
->>> print(Counter(y))
-Counter({2: 4674, 1: 262, 0: 64})
+>>> print(sorted(Counter(y).items()))
+[(0, 64), (1, 262), (2, 4674)]
 >>> from imblearn.under_sampling import ClusterCentroids
 >>> cc = ClusterCentroids(random_state=0)
 >>> X_resampled, y_resampled = cc.fit_sample(X, y)
->>> print(Counter(y_resampled))
-Counter({0: 64, 1: 64, 2: 64})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 64), (1, 64), (2, 64)]

 The figure below illustrates such under-sampling.

@@ -49,6 +49,12 @@ your data are grouped into clusters. In addition, the number of centroids
 should be set such that the under-sampled clusters are representative of the
 original one.

+.. warning::
+
+   :class:`ClusterCentroids` supports sparse matrices. However, the new samples
+   generated are not specifically sparse. Therefore, even if the resulting
+   matrix will be sparse, the algorithm will be inefficient in this regard.
+
 See :ref:`sphx_glr_auto_examples_under-sampling_plot_cluster_centroids.py` and
 :ref:`sphx_glr_auto_examples_under-sampling_plot_comparison_under_sampling.py`.

@@ -77,8 +83,8 @@ randomly selecting a subset of data for the targeted classes::
 >>> from imblearn.under_sampling import RandomUnderSampler
 >>> rus = RandomUnderSampler(random_state=0)
 >>> X_resampled, y_resampled = rus.fit_sample(X, y)
->>> print(Counter(y_resampled))
-Counter({0: 64, 1: 64, 2: 64})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 64), (1, 64), (2, 64)]

 .. image:: ./auto_examples/under-sampling/images/sphx_glr_plot_comparison_under_sampling_002.png
    :target: ./auto_examples/under-sampling/plot_comparison_under_sampling.html
@@ -108,8 +114,8 @@ be selected with the parameter ``version``::
 >>> from imblearn.under_sampling import NearMiss
 >>> nm1 = NearMiss(random_state=0, version=1)
 >>> X_resampled_nm1, y_resampled = nm1.fit_sample(X, y)
->>> print(Counter(y_resampled))
-Counter({0: 64, 1: 64, 2: 64})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 64), (1, 64), (2, 64)]

 As later stated in the next section, :class:`NearMiss` heuristic rules are
 based on nearest neighbors algorithm. Therefore, the parameters ``n_neighbors``
@@ -238,13 +244,13 @@ available: (i) the majority (i.e., ``kind_sel='mode'``) or (ii) all (i.e.,
 ``kind_sel='all'``) the nearest-neighbors have to belong to the same class than
 the sample inspected to keep it in the dataset::

->>> Counter(y)
-Counter({2: 4674, 1: 262, 0: 64})
+>>> sorted(Counter(y).items())
+[(0, 64), (1, 262), (2, 4674)]
 >>> from imblearn.under_sampling import EditedNearestNeighbours
 >>> enn = EditedNearestNeighbours(random_state=0)
 >>> X_resampled, y_resampled = enn.fit_sample(X, y)
->>> print(Counter(y_resampled))
-Counter({2: 4568, 1: 213, 0: 64})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 64), (1, 213), (2, 4568)]

 The parameter ``n_neighbors`` allows to give a classifier subclassed from
 ``KNeighborsMixin`` from scikit-learn to find the nearest neighbors and make
@@ -257,8 +263,8 @@ Generally, repeating the algorithm will delete more data::
 >>> from imblearn.under_sampling import RepeatedEditedNearestNeighbours
 >>> renn = RepeatedEditedNearestNeighbours(random_state=0)
 >>> X_resampled, y_resampled = renn.fit_sample(X, y)
->>> print(Counter(y_resampled))
-Counter({2: 4551, 1: 208, 0: 64})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 64), (1, 208), (2, 4551)]

 :class:`AllKNN` differs from the previous
 :class:`RepeatedEditedNearestNeighbours` since the number of neighbors of the
@@ -267,8 +273,8 @@ internal nearest neighbors algorithm is increased at each iteration::
 >>> from imblearn.under_sampling import AllKNN
 >>> allknn = AllKNN(random_state=0)
 >>> X_resampled, y_resampled = allknn.fit_sample(X, y)
->>> print(Counter(y_resampled))
-Counter({2: 4601, 1: 220, 0: 64})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 64), (1, 220), (2, 4601)]

 In the example below, it can be seen that the three algorithms have similar
 impact by cleaning noisy samples next to the boundaries of the classes.
@@ -305,8 +311,8 @@ The :class:`CondensedNearestNeighbour` can be used in the following manner::
 >>> from imblearn.under_sampling import CondensedNearestNeighbour
 >>> cnn = CondensedNearestNeighbour(random_state=0)
 >>> X_resampled, y_resampled = cnn.fit_sample(X, y)
->>> print(Counter(y_resampled))
-Counter({2: 116, 0: 64, 1: 25})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 64), (1, 24), (2, 115)]

 However as illustrated in the figure below, :class:`CondensedNearestNeighbour`
 is sensitive to noise and will add noisy samples.
@@ -320,8 +326,8 @@ used as::
 >>> from imblearn.under_sampling import OneSidedSelection
 >>> oss = OneSidedSelection(random_state=0)
 >>> X_resampled, y_resampled = oss.fit_sample(X, y)
->>> print(Counter(y_resampled))
-Counter({2: 4403, 1: 174, 0: 64})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 64), (1, 174), (2, 4403)]

 Our implementation offer to set the number of seeds to put in the set :math:`C`
 originally by setting the parameter ``n_seeds_S``.
@@ -334,8 +340,8 @@ neighbors classifier. The class can be used as::
 >>> from imblearn.under_sampling import NeighbourhoodCleaningRule
 >>> ncr = NeighbourhoodCleaningRule(random_state=0)
 >>> X_resampled, y_resampled = ncr.fit_sample(X, y)
->>> print(Counter(y_resampled))
-Counter({2: 4666, 1: 234, 0: 64})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 64), (1, 234), (2, 4666)]

 .. image:: ./auto_examples/under-sampling/images/sphx_glr_plot_comparison_under_sampling_005.png
    :target: ./auto_examples/under-sampling/plot_comparison_under_sampling.html
@@ -362,8 +368,8 @@ removed. The class can be used as::
 >>> iht = InstanceHardnessThreshold(random_state=0,
 ...                                 estimator=LogisticRegression())
 >>> X_resampled, y_resampled = iht.fit_sample(X, y)
->>> print(Counter(y_resampled))
-Counter({0: 64, 1: 64, 2: 64})
+>>> print(sorted(Counter(y_resampled).items()))
+[(0, 64), (1, 64), (2, 64)]

 This class has 2 important parameters. ``estimator`` will accept any
 scikit-learn classifier which has a method ``predict_proba``. The classifier
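
The :class:`ClusterCentroids` warning added above can be illustrated with a small sketch (not in the diff): the container stays sparse, but the K-means centroids are dense averages, so the stored matrix has almost no zeros::

    from scipy import sparse
    from sklearn.datasets import make_classification
    from imblearn.under_sampling import ClusterCentroids

    X, y = make_classification(n_samples=5000, n_classes=3, n_informative=3,
                               n_clusters_per_class=1,
                               weights=[0.01, 0.05, 0.94],
                               class_sep=0.8, random_state=0)
    cc = ClusterCentroids(random_state=0)
    X_resampled, y_resampled = cc.fit_sample(sparse.csr_matrix(X), y)
    # Fraction of explicitly stored entries; close to 1.0 here, so the
    # sparse format brings no real memory benefit for centroid samples.
    print(X_resampled.nnz / float(X_resampled.shape[0] * X_resampled.shape[1]))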
