Skip to content

EHN accept one-vs-all encoding for labels #410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats_new/v0.0.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ Enhancement
- Document the metrics to evaluate models on imbalanced dataset. :issue:`367`
by :user:`Guillaume Lemaitre <glemaitre>`.

- Add support for one-vs-all encoded target to support keras. :issue:`409` by
:user:`Guillaume Lemaitre <glemaitre>`.

Bug fixes
.........

Expand Down
46 changes: 31 additions & 15 deletions imblearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
import logging
from abc import ABCMeta, abstractmethod

import numpy as np

from sklearn.base import BaseEstimator
from sklearn.externals import six
from sklearn.preprocessing import label_binarize
from sklearn.utils import check_X_y
from sklearn.utils.validation import check_is_fitted

Expand Down Expand Up @@ -54,14 +57,23 @@ def sample(self, X, y):
The corresponding label of `X_resampled`

"""

# Check the consistency of X and y
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])

check_is_fitted(self, 'ratio_')
self._check_X_y(X, y)

return self._sample(X, y)
output = self._sample(X, y)

if binarize_y:
y_sampled = label_binarize(output[1], np.unique(y))
if len(output) == 2:
return output[0], y_sampled
else:
return output[0], y_sampled, output[2]
else:
return output

def fit_sample(self, X, y):
"""Fit the statistics and resample the data directly.
Expand Down Expand Up @@ -152,8 +164,8 @@ def fit(self, X, y):
Return self.

"""
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
y = check_target_type(y)
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
self.X_hash_, self.y_hash_ = hash_X_y(X, y)
# self.sampling_type is already checked in check_ratio
self.ratio_ = check_ratio(self.ratio, y, self._sampling_type)
Expand Down Expand Up @@ -232,25 +244,20 @@ def __init__(self, func=None, accept_sparse=True, kw_args=None):
self.kw_args = kw_args
self.logger = logging.getLogger(__name__)

def _check_X_y(self, X, y):
if self.accept_sparse:
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
else:
X, y = check_X_y(X, y, accept_sparse=False)
y = check_target_type(y)

return X, y

def fit(self, X, y):
X, y = self._check_X_y(X, y)
y = check_target_type(y)
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc']
if self.accept_sparse else False)
self.X_hash_, self.y_hash_ = hash_X_y(X, y)
# when using a sampler, ratio_ is supposed to exist after fit
self.ratio_ = 'is_fitted'

return self

def _sample(self, X, y, func=None, kw_args=None):
X, y = self._check_X_y(X, y)
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc']
if self.accept_sparse else False)
check_is_fitted(self, 'ratio_')
X_hash, y_hash = hash_X_y(X, y)
if self.X_hash_ != X_hash or self.y_hash_ != y_hash:
Expand All @@ -259,7 +266,16 @@ def _sample(self, X, y, func=None, kw_args=None):
if func is None:
func = _identity

return func(X, y, **(kw_args if self.kw_args else {}))
output = func(X, y, **(kw_args if self.kw_args else {}))

if binarize_y:
y_sampled = label_binarize(output[1], np.unique(y))
if len(output) == 2:
return output[0], y_sampled
else:
return output[0], y_sampled, output[2]
else:
return output

def sample(self, X, y):
return self._sample(X, y, func=self.func, kw_args=self.kw_args)
2 changes: 1 addition & 1 deletion imblearn/combine/smote_enn.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ def fit(self, X, y):
Return self.

"""
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
y = check_target_type(y)
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
self.ratio_ = self.ratio
self.X_hash_, self.y_hash_ = hash_X_y(X, y)

Expand Down
3 changes: 1 addition & 2 deletions imblearn/combine/smote_tomek.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from __future__ import division

import logging
import warnings

from sklearn.utils import check_X_y

Expand Down Expand Up @@ -153,8 +152,8 @@ def fit(self, X, y):
Return self.

"""
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
y = check_target_type(y)
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
self.ratio_ = self.ratio
self.X_hash_, self.y_hash_ = hash_X_y(X, y)

Expand Down
3 changes: 2 additions & 1 deletion imblearn/ensemble/balance_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from sklearn.model_selection import cross_val_predict

from .base import BaseEnsembleSampler
from ..utils import check_ratio
from ..utils import check_ratio, check_target_type


class BalanceCascade(BaseEnsembleSampler):
Expand Down Expand Up @@ -137,6 +137,7 @@ def fit(self, X, y):

"""
super(BalanceCascade, self).fit(X, y)
y = check_target_type(y)
self.ratio_ = check_ratio(self.ratio, y, 'under-sampling')
return self

Expand Down
50 changes: 50 additions & 0 deletions imblearn/ensemble/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
# License: MIT

import numpy as np

from sklearn.preprocessing import label_binarize
from sklearn.utils import check_X_y
from sklearn.utils.validation import check_is_fitted

from ..base import BaseSampler
from ..utils import check_target_type


class BaseEnsembleSampler(BaseSampler):
Expand All @@ -15,3 +22,46 @@ class BaseEnsembleSampler(BaseSampler):
"""

_sampling_type = 'ensemble'

def sample(self, X, y):
"""Resample the dataset.

Parameters
----------
X : {array-like, sparse matrix}, shape (n_samples, n_features)
Matrix containing the data which have to be sampled.

y : array-like, shape (n_samples,)
Corresponding label for each sample in X.

Returns
-------
X_resampled : {ndarray, sparse matrix}, shape \
(n_subset, n_samples_new, n_features)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indenting

The array containing the resampled data.

y_resampled : ndarray, shape (n_subset, n_samples_new)
The corresponding label of `X_resampled`

"""
# Ensemble are a bit specific since they are returning an array of
# resampled arrays.
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])

check_is_fitted(self, 'ratio_')
self._check_X_y(X, y)

output = self._sample(X, y)

if binarize_y:
y_resampled = output[1]
classes = np.unique(y)
y_resampled_encoded = np.array([label_binarize(batch_y, classes)
for batch_y in y_resampled])
if len(output) == 2:
return output[0], y_resampled_encoded
else:
return output[0], y_resampled_encoded, output[2]
else:
return output
23 changes: 23 additions & 0 deletions imblearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

from sklearn.datasets import make_classification
from sklearn.cluster import KMeans
from sklearn.preprocessing import label_binarize
from sklearn.utils.estimator_checks import check_estimator \
as sklearn_check_estimator, check_parameters_default_constructible
from sklearn.exceptions import NotFittedError
from sklearn.utils.testing import assert_allclose
from sklearn.utils.testing import set_random_state
from sklearn.utils.multiclass import type_of_target

from imblearn.over_sampling.base import BaseOverSampler
from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler
Expand All @@ -44,6 +46,7 @@ def _yield_sampler_checks(name, Estimator):
yield check_samplers_ratio_fit_sample
yield check_samplers_sparse
yield check_samplers_pandas
yield check_samplers_multiclass_ova


def _yield_all_checks(name, estimator):
Expand Down Expand Up @@ -253,3 +256,23 @@ def check_samplers_pandas(name, Sampler):
X_res, y_res = sampler.fit_sample(X, y)
assert_allclose(X_res_pd, X_res)
assert_allclose(y_res_pd, y_res)


def check_samplers_multiclass_ova(name, Sampler):
# Check that multiclass target lead to the same results than OVA encoding
X, y = make_classification(n_samples=1000, n_classes=3,
n_informative=4, weights=[0.2, 0.3, 0.5],
random_state=0)
y_ova = label_binarize(y, np.unique(y))
sampler = Sampler()
set_random_state(sampler)
X_res, y_res = sampler.fit_sample(X, y)
X_res_ova, y_res_ova = sampler.fit_sample(X, y_ova)
assert_allclose(X_res, X_res_ova)
if issubclass(Sampler, BaseEnsembleSampler):
for batch_y, batch_y_ova in zip(y_res, y_res_ova):
assert type_of_target(batch_y_ova) == type_of_target(y_ova)
assert_allclose(batch_y, batch_y_ova.argmax(axis=1))
else:
assert type_of_target(y_res_ova) == type_of_target(y_ova)
assert_allclose(y_res, y_res_ova.argmax(axis=1))
33 changes: 33 additions & 0 deletions imblearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@
from collections import Counter

import numpy as np
import pytest
from pytest import raises

from sklearn.neighbors.base import KNeighborsMixin
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import check_random_state
from sklearn.externals import joblib
from sklearn.utils.testing import assert_array_equal

from imblearn.utils.testing import warns
from imblearn.utils import check_neighbors_object
from imblearn.utils import check_ratio
from imblearn.utils import hash_X_y
from imblearn.utils import check_target_type


def test_check_neighbors_object():
Expand All @@ -35,6 +38,36 @@ def test_check_neighbors_object():
check_neighbors_object(name, n_neighbors)


@pytest.mark.parametrize(
"target, output_target",
[(np.array([0, 1, 1]), np.array([0, 1, 1])),
(np.array([0, 1, 2]), np.array([0, 1, 2])),
(np.array([[0, 1], [1, 0]]), np.array([1, 0]))]
)
def test_check_target_type(target, output_target):
converted_target = check_target_type(target.astype(int))
assert_array_equal(converted_target, output_target.astype(int))


@pytest.mark.parametrize(
"target, output_target, is_ova",
[(np.array([0, 1, 1]), np.array([0, 1, 1]), False),
(np.array([0, 1, 2]), np.array([0, 1, 2]), False),
(np.array([[0, 1], [1, 0]]), np.array([1, 0]), True)]
)
def test_check_target_type_ova(target, output_target, is_ova):
converted_target, binarize_target = check_target_type(
target.astype(int), indicate_one_vs_all=True)
assert_array_equal(converted_target, output_target.astype(int))
assert binarize_target == is_ova


def test_check_target_warning():
target = np.arange(4).reshape((2, 2))
with pytest.warns(UserWarning, match='should be of types'):
check_target_type(target)


def test_check_ratio_error():
with raises(ValueError, match="'sampling_type' should be one of"):
check_ratio('auto', np.array([1, 2, 3]), 'rnd')
Expand Down
27 changes: 20 additions & 7 deletions imblearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

SAMPLING_KIND = ('over-sampling', 'under-sampling', 'clean-sampling',
'ensemble')
TARGET_KIND = ('binary', 'multiclass')
TARGET_KIND = ('binary', 'multiclass', 'multilabel-indicator')


def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
Expand Down Expand Up @@ -54,29 +54,42 @@ def check_neighbors_object(nn_name, nn_object, additional_neighbor=0):
raise_isinstance_error(nn_name, [int, KNeighborsMixin], nn_object)


def check_target_type(y):
def check_target_type(y, indicate_one_vs_all=False):
"""Check the target types to be conform to the current samplers.

The current samplers should be compatible with ``'binary'`` and
``'multiclass'`` targets only.
The current samplers should be compatible with ``'binary'``,
``'multilabel-indicator'`` and ``'multiclass'`` targets only.

Parameters
----------
y : ndarray,
The array containing the target
The array containing the target.

indicate_one_vs_all : bool, optional
Either to indicate if the targets are encoded in a one-vs-all fashion.

Returns
-------
y : ndarray,
The returned target.

is_one_vs_all : bool, optional
Indicate if the target was originally encoded in a one-vs-all fashion.
Only returned if ``indicate_multilabel=True``.

"""
if type_of_target(y) not in TARGET_KIND:
type_y = type_of_target(y)
if type_y not in TARGET_KIND:
# FIXME: perfectly we should raise an error but the sklearn API does
# not allow for it
warnings.warn("'y' should be of types {} only. Got {} instead.".format(
TARGET_KIND, type_of_target(y)))
return y

if indicate_one_vs_all:
return (y.argmax(axis=1) if type_y == 'multilabel-indicator' else y,
type_y == 'multilabel-indicator')
else:
return y.argmax(axis=1) if type_y == 'multilabel-indicator' else y


def hash_X_y(X, y, n_samples=10, n_features=5):
Expand Down