Commit 2940dbb

iter

1 parent 96ef7d4 commit 2940dbb
File tree

1 file changed: +313 -1 lines changed

examples/applications/plot_impact_imbalanced_classes.py

Lines changed: 313 additions & 1 deletion
@@ -35,7 +35,7 @@
from collections import Counter

classes_count = y.value_counts()
-print(f"Classes information:\n{classes_count}")
+classes_count

###############################################################################
# This dataset is only slightly imbalanced. To better highlight the effect of
@@ -49,6 +49,7 @@
        classes_count.idxmin(): classes_count.max() // ratio
    }
)
+y_res.value_counts()

###############################################################################
# For the rest of the notebook, we will make a single split to get training
@@ -80,3 +81,314 @@
y_pred = dummy_clf.predict(X_test)
score = balanced_accuracy_score(y_test, y_pred)
print(f"Balanced accuracy score of a dummy classifier: {score:.3f}")

###############################################################################
# Strategies to learn from an imbalanced dataset
###############################################################################

###############################################################################
# We will first define a helper function which will train a given model and
# compute both accuracy and balanced accuracy. The results will be stored in a
# dataframe.

import pandas as pd


def evaluate_classifier(clf, df_scores, clf_name=None):
    # fit the classifier, compute accuracy and balanced accuracy on the test
    # set, and append the scores as a new column of `df_scores`
    from sklearn.pipeline import Pipeline
    if clf_name is None:
        if isinstance(clf, Pipeline):
            clf_name = clf[-1].__class__.__name__
        else:
            clf_name = clf.__class__.__name__
    acc = clf.fit(X_train, y_train).score(X_test, y_test)
    y_pred = clf.predict(X_test)
    bal_acc = balanced_accuracy_score(y_test, y_pred)
    clf_score = pd.DataFrame(
        {clf_name: [acc, bal_acc]},
        index=['Accuracy', 'Balanced accuracy']
    )
    df_scores = pd.concat([df_scores, clf_score], axis=1).round(decimals=3)
    return df_scores


# Let's define an empty dataframe to store the results
df_scores = pd.DataFrame()

###############################################################################
# Dummy baseline
# ..............
#
# Before training a real machine learning model, we can store the results
# obtained with our `DummyClassifier`.

df_scores = evaluate_classifier(dummy_clf, df_scores, "Dummy")
df_scores

###############################################################################
# Linear classifier baseline
# ..........................
#
# We will create a machine learning pipeline using a `LogisticRegression`
# classifier. In this regard, we will need to one-hot encode the categorical
# columns and standardize the numerical columns before injecting the data into
# the `LogisticRegression` classifier.
#
# First, we define our numerical and categorical pipelines.

from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import make_pipeline

num_pipe = make_pipeline(
    StandardScaler(), SimpleImputer(strategy="mean", add_indicator=True)
)
cat_pipe = make_pipeline(
    SimpleImputer(strategy="constant", fill_value="missing"),
    OneHotEncoder(handle_unknown="ignore")
)

###############################################################################
# Then, we can create a preprocessor which will dispatch the categorical
# columns to the categorical pipeline and the numerical columns to the
# numerical pipeline.

import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.compose import make_column_selector as selector

preprocessor_linear = ColumnTransformer(
    [("num-pipe", num_pipe, selector(dtype_include=np.number)),
     ("cat-pipe", cat_pipe, selector(dtype_include=pd.CategoricalDtype))]
)

###############################################################################
# Finally, we connect our preprocessor with our `LogisticRegression`. We can
# then evaluate our model.

from sklearn.linear_model import LogisticRegression

lr_clf = make_pipeline(
    preprocessor_linear, LogisticRegression(max_iter=1000)
)
df_scores = evaluate_classifier(lr_clf, df_scores, "LR")
df_scores

###############################################################################
# We can see that our linear model is learning slightly better than our dummy
# baseline. However, it is impacted by the class imbalance.
#
# We can verify that something similar is happening with a tree-based model
# such as `RandomForestClassifier`. With this type of classifier, we will not
# need to scale the numerical data, and we will only need to ordinal encode the
# categorical data.

from sklearn.preprocessing import OrdinalEncoder
from sklearn.ensemble import RandomForestClassifier

cat_pipe = make_pipeline(
    SimpleImputer(strategy="constant", fill_value="missing"),
    OrdinalEncoder()
)

preprocessor_tree = ColumnTransformer(
    [("num-pipe", num_pipe, selector(dtype_include=np.number)),
     ("cat-pipe", cat_pipe, selector(dtype_include=pd.CategoricalDtype))]
)

rf_clf = make_pipeline(
    preprocessor_tree, RandomForestClassifier(random_state=42)
)

df_scores = evaluate_classifier(rf_clf, df_scores, "RF")
df_scores

###############################################################################
# The `RandomForestClassifier` is also affected by the class imbalance,
# although slightly less than the linear model. Now, we will present different
# approaches to improve the performance of these two models.
#
# Use `class_weight`
# ..................
#
# Most of the models in `scikit-learn` have a parameter `class_weight`. This
# parameter affects the computation of the loss in linear models or of the
# criterion in tree-based models, so that misclassifying the minority class is
# penalized differently from misclassifying the majority class. We can set
# `class_weight="balanced"` such that the weight applied is inversely
# proportional to the class frequency. We test this parametrization on both the
# linear model and the tree-based model.

lr_clf.set_params(logisticregression__class_weight="balanced")
df_scores = evaluate_classifier(
    lr_clf, df_scores, "LR with class weight"
)
df_scores
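
###############################################################################
# As an illustration of this heuristic, the "balanced" weights can be inspected
# with `compute_class_weight` (a minimal sketch, not needed by the pipelines
# above): each class receives a weight proportional to
# `n_samples / (n_classes * class_count)`.

from sklearn.utils.class_weight import compute_class_weight

compute_class_weight(
    class_weight="balanced", classes=np.unique(y_train), y=y_train
)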

###############################################################################
#

rf_clf.set_params(randomforestclassifier__class_weight="balanced")
df_scores = evaluate_classifier(
    rf_clf, df_scores, "RF with class weight"
)
df_scores

###############################################################################
# We can see that using `class_weight` was really effective for the linear
# model, alleviating the issue of learning from imbalanced classes. However,
# the `RandomForestClassifier` is still biased toward the majority class,
# mainly because its splitting criterion is not well suited to fight class
# imbalance.
#
# Resample the training set during learning
# .........................................
#
# Another way is to resample the training set by under-sampling or
# over-sampling some of the samples. `imbalanced-learn` provides some samplers
# to do such processing.

from imblearn.pipeline import make_pipeline as make_pipeline_with_sampler
from imblearn.under_sampling import RandomUnderSampler
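
# As a minimal sketch of what such a sampler does, we can resample the training
# set on its own and look at the class counts (`Counter` is imported at the top
# of this example; the `y_under` name is only used for this illustration).
_, y_under = RandomUnderSampler(random_state=42).fit_resample(X_train, y_train)
print(Counter(y_under))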

lr_clf = make_pipeline_with_sampler(
    preprocessor_linear,
    RandomUnderSampler(random_state=42),
    LogisticRegression(max_iter=1000)
)
df_scores = evaluate_classifier(
    lr_clf, df_scores, "LR with under-sampling"
)
df_scores

###############################################################################
#

rf_clf = make_pipeline_with_sampler(
    preprocessor_tree,
    RandomUnderSampler(random_state=42),
    RandomForestClassifier(random_state=42)
)

df_scores = evaluate_classifier(
    rf_clf, df_scores, "RF with under-sampling"
)
df_scores

###############################################################################
# Applying a random under-sampler before training the linear model or the
# random forest keeps these models from focusing only on the majority class, at
# the cost of making more mistakes on samples from the majority class (i.e. a
# decreased accuracy).
#
# We could apply any type of sampler and find which one works best on the
# current dataset.
#
# Instead, we will present another approach, using classifiers which apply
# resampling internally.
#
# Use of `BalancedRandomForestClassifier` and `BalancedBaggingClassifier`
# .......................................................................
#
# We already showed that random under-sampling can be effective on decision
# trees. However, instead of under-sampling the dataset once, one could
# under-sample the original dataset before taking each bootstrap sample. This
# is the basis of the `BalancedRandomForestClassifier` and
# `BalancedBaggingClassifier`.
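#
# To make the connection concrete, here is a minimal sketch (the `tree_bag_clf`
# name is only illustrative and this pipeline is not evaluated below): bagging
# decision trees, each fitted on a balanced bootstrap sample, is close in
# spirit to the `BalancedRandomForestClassifier`, which additionally draws a
# random subset of features at each split.

from sklearn.tree import DecisionTreeClassifier
from imblearn.ensemble import BalancedBaggingClassifier

tree_bag_clf = make_pipeline(
    preprocessor_tree,
    BalancedBaggingClassifier(
        base_estimator=DecisionTreeClassifier(),
        n_estimators=100, random_state=42
    )
)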

from imblearn.ensemble import BalancedRandomForestClassifier

rf_clf = make_pipeline(
    preprocessor_tree,
    BalancedRandomForestClassifier(random_state=42)
)

df_scores = evaluate_classifier(rf_clf, df_scores, "Balanced RF")
df_scores

###############################################################################
# The performance of the `BalancedRandomForestClassifier` is better than
# applying a single random under-sampling. We will now use a gradient-boosting
# classifier within a `BalancedBaggingClassifier`.

from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
from imblearn.ensemble import BalancedBaggingClassifier

bag_clf = make_pipeline(
    preprocessor_tree,
    BalancedBaggingClassifier(
        base_estimator=HistGradientBoostingClassifier(random_state=42),
        n_estimators=10, random_state=42
    )
)

df_scores = evaluate_classifier(
    bag_clf, df_scores, "Balanced bagging"
)
df_scores

###############################################################################
# This last approach is the most effective. The different under-samplings bring
# some diversity for the different GBDTs to learn from, such that they do not
# all focus on the same portion of the majority class.
#
# We will repeat the same experiment but with a ratio of 100:1 and make a
# similar analysis.

###############################################################################
# Increase the imbalance ratio
###############################################################################

ratio = 100
df_res, y_res = make_imbalance(
    df, y, sampling_strategy={
        classes_count.idxmin(): classes_count.max() // ratio
    }
)
X_train, X_test, y_train, y_test = train_test_split(
    df_res, y_res, stratify=y_res, random_state=42
)

df_scores = pd.DataFrame()
df_scores = evaluate_classifier(dummy_clf, df_scores, "Dummy")
lr_clf = make_pipeline(
    preprocessor_linear, LogisticRegression(max_iter=1000)
)
df_scores = evaluate_classifier(lr_clf, df_scores, "LR")
rf_clf = make_pipeline(
    preprocessor_tree, RandomForestClassifier(random_state=42)
)
df_scores = evaluate_classifier(rf_clf, df_scores, "RF")
lr_clf.set_params(logisticregression__class_weight="balanced")
df_scores = evaluate_classifier(
    lr_clf, df_scores, "LR with class weight"
)
rf_clf.set_params(randomforestclassifier__class_weight="balanced")
df_scores = evaluate_classifier(
    rf_clf, df_scores, "RF with class weight"
)
lr_clf = make_pipeline_with_sampler(
    preprocessor_linear,
    RandomUnderSampler(random_state=42),
    LogisticRegression(max_iter=1000)
)
df_scores = evaluate_classifier(
    lr_clf, df_scores, "LR with under-sampling"
)
rf_clf = make_pipeline_with_sampler(
    preprocessor_tree,
    RandomUnderSampler(random_state=42),
    RandomForestClassifier(random_state=42)
)
df_scores = evaluate_classifier(
    rf_clf, df_scores, "RF with under-sampling"
)
rf_clf = make_pipeline(
    preprocessor_tree,
    BalancedRandomForestClassifier(random_state=42)
)
df_scores = evaluate_classifier(rf_clf, df_scores, "Balanced RF")
df_scores = evaluate_classifier(
    bag_clf, df_scores, "Balanced bagging"
)
df_scores
