Commit 24f4973

EHN accept one-vs-all encoding for labels (#410)
1 parent d85c17a commit 24f4973

9 files changed: +164 −26 lines

doc/whats_new/v0.0.4.rst

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@ Enhancement
 - Document the metrics to evaluate models on imbalanced dataset. :issue:`367`
   by :user:`Guillaume Lemaitre <glemaitre>`.
 
+- Add support for one-vs-all encoded target to support keras. :issue:`409` by
+  :user:`Guillaume Lemaitre <glemaitre>`.
+
 Bug fixes
 .........
 
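For illustration, a minimal usage sketch of the behaviour described by this changelog entry; the sampler, dataset, and printed shape are illustrative assumptions, not part of the commit. A one-vs-all (one-hot) encoded target can now be passed to a sampler and the resampled target comes back in the same encoding, which makes it easy to plug the output into a keras model:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.preprocessing import label_binarize
from imblearn.under_sampling import RandomUnderSampler  # any sampler should behave the same

# Build an imbalanced 3-class problem and encode the target one-vs-all.
X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
                           weights=[0.2, 0.3, 0.5], random_state=0)
y_ova = label_binarize(y, np.unique(y))          # shape (1000, 3)

# The resampled target keeps the one-vs-all encoding of the input.
X_res, y_res = RandomUnderSampler(random_state=0).fit_sample(X, y_ova)
print(y_res.shape)                               # (n_samples_new, 3)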

imblearn/base.py

Lines changed: 31 additions & 15 deletions
@@ -9,8 +9,11 @@
 import logging
 from abc import ABCMeta, abstractmethod
 
+import numpy as np
+
 from sklearn.base import BaseEstimator
 from sklearn.externals import six
+from sklearn.preprocessing import label_binarize
 from sklearn.utils import check_X_y
 from sklearn.utils.validation import check_is_fitted
 
@@ -54,14 +57,23 @@ def sample(self, X, y):
             The corresponding label of `X_resampled`
 
         """
-
         # Check the consistency of X and y
+        y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
         X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
 
         check_is_fitted(self, 'ratio_')
         self._check_X_y(X, y)
 
-        return self._sample(X, y)
+        output = self._sample(X, y)
+
+        if binarize_y:
+            y_sampled = label_binarize(output[1], np.unique(y))
+            if len(output) == 2:
+                return output[0], y_sampled
+            else:
+                return output[0], y_sampled, output[2]
+        else:
+            return output
 
     def fit_sample(self, X, y):
         """Fit the statistics and resample the data directly.
@@ -152,8 +164,8 @@ def fit(self, X, y):
             Return self.
 
         """
-        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
         y = check_target_type(y)
+        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
         self.X_hash_, self.y_hash_ = hash_X_y(X, y)
         # self.sampling_type is already checked in check_ratio
         self.ratio_ = check_ratio(self.ratio, y, self._sampling_type)
@@ -232,25 +244,20 @@ def __init__(self, func=None, accept_sparse=True, kw_args=None):
         self.kw_args = kw_args
         self.logger = logging.getLogger(__name__)
 
-    def _check_X_y(self, X, y):
-        if self.accept_sparse:
-            X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
-        else:
-            X, y = check_X_y(X, y, accept_sparse=False)
-        y = check_target_type(y)
-
-        return X, y
-
     def fit(self, X, y):
-        X, y = self._check_X_y(X, y)
+        y = check_target_type(y)
+        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc']
+                         if self.accept_sparse else False)
         self.X_hash_, self.y_hash_ = hash_X_y(X, y)
         # when using a sampler, ratio_ is supposed to exist after fit
         self.ratio_ = 'is_fitted'
 
         return self
 
     def _sample(self, X, y, func=None, kw_args=None):
-        X, y = self._check_X_y(X, y)
+        y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
+        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc']
+                         if self.accept_sparse else False)
         check_is_fitted(self, 'ratio_')
         X_hash, y_hash = hash_X_y(X, y)
         if self.X_hash_ != X_hash or self.y_hash_ != y_hash:
@@ -259,7 +266,16 @@ def _sample(self, X, y, func=None, kw_args=None):
         if func is None:
             func = _identity
 
-        return func(X, y, **(kw_args if self.kw_args else {}))
+        output = func(X, y, **(kw_args if self.kw_args else {}))
+
+        if binarize_y:
+            y_sampled = label_binarize(output[1], np.unique(y))
+            if len(output) == 2:
+                return output[0], y_sampled
+            else:
+                return output[0], y_sampled, output[2]
+        else:
+            return output
 
     def sample(self, X, y):
         return self._sample(X, y, func=self.func, kw_args=self.kw_args)
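The mechanism behind the base-class change above is a decode/re-encode round trip: a one-vs-all target is collapsed to class indices with argmax before resampling, and the sampled labels are re-binarized with label_binarize afterwards. A standalone sketch of that round trip (the arrays are made up for illustration):

import numpy as np
from sklearn.preprocessing import label_binarize

y_ova = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]])
y_int = y_ova.argmax(axis=1)                      # decode: array([0, 1, 2, 1])
# ... the sampler works on the integer labels exactly as before ...
y_back = label_binarize(y_int, np.unique(y_int))  # re-encode the sampled labels
assert np.array_equal(y_back, y_ova)              # original encoding is recovered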

imblearn/combine/smote_enn.py

Lines changed: 1 addition & 1 deletion
@@ -144,8 +144,8 @@ def fit(self, X, y):
             Return self.
 
         """
-        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
         y = check_target_type(y)
+        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
         self.ratio_ = self.ratio
         self.X_hash_, self.y_hash_ = hash_X_y(X, y)
 

imblearn/combine/smote_tomek.py

Lines changed: 1 addition & 2 deletions
@@ -8,7 +8,6 @@
 from __future__ import division
 
 import logging
-import warnings
 
 from sklearn.utils import check_X_y
 
@@ -153,8 +152,8 @@ def fit(self, X, y):
             Return self.
 
         """
-        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
         y = check_target_type(y)
+        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
         self.ratio_ = self.ratio
         self.X_hash_, self.y_hash_ = hash_X_y(X, y)
 

imblearn/ensemble/balance_cascade.py

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,7 @@
 from sklearn.model_selection import cross_val_predict
 
 from .base import BaseEnsembleSampler
-from ..utils import check_ratio
+from ..utils import check_ratio, check_target_type
 
 
 class BalanceCascade(BaseEnsembleSampler):
@@ -137,6 +137,7 @@ def fit(self, X, y):
             Return self.
 
         """
         super(BalanceCascade, self).fit(X, y)
+        y = check_target_type(y)
         self.ratio_ = check_ratio(self.ratio, y, 'under-sampling')
         return self

imblearn/ensemble/base.py

Lines changed: 50 additions & 0 deletions
@@ -4,7 +4,14 @@
 # Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
 # License: MIT
 
+import numpy as np
+
+from sklearn.preprocessing import label_binarize
+from sklearn.utils import check_X_y
+from sklearn.utils.validation import check_is_fitted
+
 from ..base import BaseSampler
+from ..utils import check_target_type
 
 
 class BaseEnsembleSampler(BaseSampler):
@@ -15,3 +22,46 @@ class BaseEnsembleSampler(BaseSampler):
     """
 
     _sampling_type = 'ensemble'
+
+    def sample(self, X, y):
+        """Resample the dataset.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_samples, n_features)
+            Matrix containing the data which have to be sampled.
+
+        y : array-like, shape (n_samples,)
+            Corresponding label for each sample in X.
+
+        Returns
+        -------
+        X_resampled : {ndarray, sparse matrix}, shape \
+            (n_subset, n_samples_new, n_features)
+            The array containing the resampled data.
+
+        y_resampled : ndarray, shape (n_subset, n_samples_new)
+            The corresponding label of `X_resampled`
+
+        """
+        # Ensemble are a bit specific since they are returning an array of
+        # resampled arrays.
+        y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
+        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
+
+        check_is_fitted(self, 'ratio_')
+        self._check_X_y(X, y)
+
+        output = self._sample(X, y)
+
+        if binarize_y:
+            y_resampled = output[1]
+            classes = np.unique(y)
+            y_resampled_encoded = np.array([label_binarize(batch_y, classes)
+                                            for batch_y in y_resampled])
+            if len(output) == 2:
+                return output[0], y_resampled_encoded
+            else:
+                return output[0], y_resampled_encoded, output[2]
+        else:
+            return output
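Because ensemble samplers return one resampled set per subset, the base class above re-encodes each batch of labels separately, so a one-vs-all input yields a three-dimensional target. A sketch assuming EasyEnsemble is the sampler (the sampler choice and data are illustrative, not part of the commit):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.preprocessing import label_binarize
from imblearn.ensemble import EasyEnsemble

X, y = make_classification(n_samples=1000, n_classes=3, n_informative=4,
                           weights=[0.2, 0.3, 0.5], random_state=0)
y_ova = label_binarize(y, np.unique(y))

X_res, y_res = EasyEnsemble(random_state=0).fit_sample(X, y_ova)
# X_res: (n_subset, n_samples_new, n_features)
# y_res: (n_subset, n_samples_new, n_classes) -- one one-vs-all block per subset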

imblearn/utils/estimator_checks.py

Lines changed: 23 additions & 0 deletions
@@ -19,11 +19,13 @@
 
 from sklearn.datasets import make_classification
 from sklearn.cluster import KMeans
+from sklearn.preprocessing import label_binarize
 from sklearn.utils.estimator_checks import check_estimator \
     as sklearn_check_estimator, check_parameters_default_constructible
 from sklearn.exceptions import NotFittedError
 from sklearn.utils.testing import assert_allclose
 from sklearn.utils.testing import set_random_state
+from sklearn.utils.multiclass import type_of_target
 
 from imblearn.over_sampling.base import BaseOverSampler
 from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler
@@ -44,6 +46,7 @@ def _yield_sampler_checks(name, Estimator):
     yield check_samplers_ratio_fit_sample
     yield check_samplers_sparse
     yield check_samplers_pandas
+    yield check_samplers_multiclass_ova
 
 
 def _yield_all_checks(name, estimator):
@@ -253,3 +256,23 @@ def check_samplers_pandas(name, Sampler):
     X_res, y_res = sampler.fit_sample(X, y)
     assert_allclose(X_res_pd, X_res)
     assert_allclose(y_res_pd, y_res)
+
+
+def check_samplers_multiclass_ova(name, Sampler):
+    # Check that multiclass target lead to the same results than OVA encoding
+    X, y = make_classification(n_samples=1000, n_classes=3,
+                               n_informative=4, weights=[0.2, 0.3, 0.5],
+                               random_state=0)
+    y_ova = label_binarize(y, np.unique(y))
+    sampler = Sampler()
+    set_random_state(sampler)
+    X_res, y_res = sampler.fit_sample(X, y)
+    X_res_ova, y_res_ova = sampler.fit_sample(X, y_ova)
+    assert_allclose(X_res, X_res_ova)
+    if issubclass(Sampler, BaseEnsembleSampler):
+        for batch_y, batch_y_ova in zip(y_res, y_res_ova):
+            assert type_of_target(batch_y_ova) == type_of_target(y_ova)
+            assert_allclose(batch_y, batch_y_ova.argmax(axis=1))
+    else:
+        assert type_of_target(y_res_ova) == type_of_target(y_ova)
+        assert_allclose(y_res, y_res_ova.argmax(axis=1))

imblearn/utils/tests/test_validation.py

Lines changed: 33 additions & 0 deletions
@@ -6,17 +6,20 @@
 from collections import Counter
 
 import numpy as np
+import pytest
 from pytest import raises
 
 from sklearn.neighbors.base import KNeighborsMixin
 from sklearn.neighbors import NearestNeighbors
 from sklearn.utils import check_random_state
 from sklearn.externals import joblib
+from sklearn.utils.testing import assert_array_equal
 
 from imblearn.utils.testing import warns
 from imblearn.utils import check_neighbors_object
 from imblearn.utils import check_ratio
 from imblearn.utils import hash_X_y
+from imblearn.utils import check_target_type
 
 
 def test_check_neighbors_object():
@@ -35,6 +38,36 @@ def test_check_neighbors_object():
         check_neighbors_object(name, n_neighbors)
 
 
+@pytest.mark.parametrize(
+    "target, output_target",
+    [(np.array([0, 1, 1]), np.array([0, 1, 1])),
+     (np.array([0, 1, 2]), np.array([0, 1, 2])),
+     (np.array([[0, 1], [1, 0]]), np.array([1, 0]))]
+)
+def test_check_target_type(target, output_target):
+    converted_target = check_target_type(target.astype(int))
+    assert_array_equal(converted_target, output_target.astype(int))
+
+
+@pytest.mark.parametrize(
+    "target, output_target, is_ova",
+    [(np.array([0, 1, 1]), np.array([0, 1, 1]), False),
+     (np.array([0, 1, 2]), np.array([0, 1, 2]), False),
+     (np.array([[0, 1], [1, 0]]), np.array([1, 0]), True)]
+)
+def test_check_target_type_ova(target, output_target, is_ova):
+    converted_target, binarize_target = check_target_type(
+        target.astype(int), indicate_one_vs_all=True)
+    assert_array_equal(converted_target, output_target.astype(int))
+    assert binarize_target == is_ova
+
+
+def test_check_target_warning():
+    target = np.arange(4).reshape((2, 2))
+    with pytest.warns(UserWarning, match='should be of types'):
+        check_target_type(target)
+
+
 def test_check_ratio_error():
     with raises(ValueError, match="'sampling_type' should be one of"):
         check_ratio('auto', np.array([1, 2, 3]), 'rnd')

imblearn/utils/validation.py

Lines changed: 20 additions & 7 deletions
@@ -19,7 +19,7 @@
 
 SAMPLING_KIND = ('over-sampling', 'under-sampling', 'clean-sampling',
                  'ensemble')
-TARGET_KIND = ('binary', 'multiclass')
+TARGET_KIND = ('binary', 'multiclass', 'multilabel-indicator')
 
 
 def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
@@ -54,29 +54,42 @@ def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
         raise_isinstance_error(nn_name, [int, KNeighborsMixin], nn_object)
 
 
-def check_target_type(y):
+def check_target_type(y, indicate_one_vs_all=False):
     """Check the target types to be conform to the current samplers.
 
-    The current samplers should be compatible with ``'binary'`` and
-    ``'multiclass'`` targets only.
+    The current samplers should be compatible with ``'binary'``,
+    ``'multilabel-indicator'`` and ``'multiclass'`` targets only.
 
     Parameters
     ----------
     y : ndarray,
-        The array containing the target
+        The array containing the target.
+
+    indicate_one_vs_all : bool, optional
+        Either to indicate if the targets are encoded in a one-vs-all fashion.
 
     Returns
    -------
     y : ndarray,
         The returned target.
 
+    is_one_vs_all : bool, optional
+        Indicate if the target was originally encoded in a one-vs-all fashion.
+        Only returned if ``indicate_multilabel=True``.
+
     """
-    if type_of_target(y) not in TARGET_KIND:
+    type_y = type_of_target(y)
+    if type_y not in TARGET_KIND:
         # FIXME: perfectly we should raise an error but the sklearn API does
         # not allow for it
         warnings.warn("'y' should be of types {} only. Got {} instead.".format(
             TARGET_KIND, type_of_target(y)))
-    return y
+
+    if indicate_one_vs_all:
+        return (y.argmax(axis=1) if type_y == 'multilabel-indicator' else y,
+                type_y == 'multilabel-indicator')
+    else:
+        return y.argmax(axis=1) if type_y == 'multilabel-indicator' else y
 
 
 def hash_X_y(X, y, n_samples=10, n_features=5):
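In short, the helper now collapses a one-vs-all indicator matrix to class indices and can optionally report whether that conversion happened; a minimal sketch of both call modes, mirroring the new tests:

import numpy as np
from imblearn.utils import check_target_type

y_ova = np.array([[0, 1], [1, 0]])
y = check_target_type(y_ova)                                    # array([1, 0])
y, is_ova = check_target_type(y_ova, indicate_one_vs_all=True)  # (array([1, 0]), True)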
