
Commit fe51ca3

EXA: improve FunctionTransformer example
1 parent e49c30a commit fe51ca3

1 file changed: +30 -2 lines changed

examples/plot_outlier_rejections.py

Lines changed: 30 additions & 2 deletions
@@ -30,14 +30,21 @@
 
 
 def plot_scatter(X, y, title):
+    """Function to plot some data as a scatter plot."""
     plt.figure()
     plt.scatter(X[y == 1, 0], X[y == 1, 1], label='Class #1')
     plt.scatter(X[y == 0, 0], X[y == 0, 1], label='Class #0')
     plt.legend()
     plt.title(title)
 
+##############################################################################
+# Toy data generation
+##############################################################################
+
+##############################################################################
+# We are generating a non-Gaussian dataset contaminated with some uniform
+# noise.
 
-# Generate contaminated training data
 moons, _ = make_moons(n_samples=500, noise=0.05)
 blobs, _ = make_blobs(n_samples=500, centers=[(-0.75, 2.25),
                                               (1.0, 2.0)],
@@ -51,7 +58,9 @@ def plot_scatter(X, y, title):
 
 plot_scatter(X_train, y_train, 'Training dataset')
 
-# Generate non-contaminated testing data
+##############################################################################
+# We will generate some clean test data without outliers.
+
 moons, _ = make_moons(n_samples=50, noise=0.05)
 blobs, _ = make_blobs(n_samples=50, centers=[(-0.75, 2.25),
                                              (1.0, 2.0)],
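As a reading aid, here is a minimal sketch of how a contaminated training set of this kind can be assembled. The stacking of the moons, the blobs and the uniform noise happens outside the hunks shown above, so the noise bounds, cluster_std and the labels given to the outliers below are illustrative assumptions rather than values taken from this commit.

import numpy as np
from sklearn.datasets import make_blobs, make_moons

rng = np.random.RandomState(42)

# Two half-moons and two Gaussian blobs form the two "clean" classes.
moons, _ = make_moons(n_samples=500, noise=0.05)
blobs, _ = make_blobs(n_samples=500, centers=[(-0.75, 2.25), (1.0, 2.0)],
                      cluster_std=0.25)  # cluster_std is an assumption

# Uniform contamination; bounds and outlier labels are assumptions.
outliers = rng.uniform(low=-3, high=3, size=(500, 2))

X_train = np.vstack([moons, blobs, outliers])
y_train = np.hstack([np.ones(moons.shape[0], dtype=int),
                     np.zeros(blobs.shape[0], dtype=int),
                     rng.randint(0, 2, size=outliers.shape[0])])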
@@ -62,8 +71,19 @@ def plot_scatter(X, y, title):
 
 plot_scatter(X_test, y_test, 'Testing dataset')
 
+##############################################################################
+# How to use the :class:`imblearn.FunctionSampler`
+##############################################################################
+
+##############################################################################
+# We first define a function which will use
+# :class:`sklearn.ensemble.IsolationForest` to eliminate some outliers from
+# our dataset during training. The function passed to the
+# :class:`imblearn.FunctionSampler` will be called when using the method
+# ``fit_resample``.
 
 def outlier_rejection(X, y):
+    """This will be our function used to resample our dataset."""
     model = IsolationForest(max_samples=100,
                             contamination=0.4,
                             random_state=rng)
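The hunk above stops at the IsolationForest constructor. For context, a self-contained sketch of the resampling step being documented could look like the following; the body of outlier_rejection after the constructor is an assumption consistent with the comments (IsolationForest.predict returns +1 for inliers and -1 for outliers), while FunctionSampler and fit_resample are the imblearn API named in the diff.

import numpy as np
from sklearn.datasets import make_moons
from sklearn.ensemble import IsolationForest

from imblearn import FunctionSampler

rng = np.random.RandomState(42)
X, y = make_moons(n_samples=500, noise=0.05)


def outlier_rejection(X, y):
    """Keep only the samples that IsolationForest flags as inliers."""
    model = IsolationForest(max_samples=100,
                            contamination=0.4,
                            random_state=rng)
    model.fit(X)
    y_pred = model.predict(X)  # +1 for inliers, -1 for outliers
    return X[y_pred == 1], y[y_pred == 1]


# The wrapped function is only called through fit_resample, i.e. at fit time,
# so the rejection only ever touches training data.
reject_sampler = FunctionSampler(func=outlier_rejection)
X_inliers, y_inliers = reject_sampler.fit_resample(X, y)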
@@ -76,6 +96,14 @@ def outlier_rejection(X, y):
 X_inliers, y_inliers = reject_sampler.fit_resample(X_train, y_train)
 plot_scatter(X_inliers, y_inliers, 'Training data without outliers')
 
+##############################################################################
+# Integrate it within a pipeline
+##############################################################################
+
+##############################################################################
+# By eliminating outliers before training, the classifier will be less
+# affected during prediction.
+
 pipe = make_pipeline(FunctionSampler(func=outlier_rejection),
                      LogisticRegression(random_state=rng))
 y_pred = pipe.fit(X_train, y_train).predict(X_test)
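The new comment claims that rejecting outliers before training leaves the classifier less affected at prediction time. A hedged sketch of how one could check that against a plain classifier is given below; it reuses X_train, y_train, X_test, y_test, rng and outlier_rejection from the example, and the classification_report comparison is an illustration rather than part of this commit.

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report

from imblearn import FunctionSampler
from imblearn.pipeline import make_pipeline

# Pipeline version: outliers are rejected during fit, prediction is untouched.
pipe = make_pipeline(FunctionSampler(func=outlier_rejection),
                     LogisticRegression(random_state=rng))
y_pred = pipe.fit(X_train, y_train).predict(X_test)
print(classification_report(y_test, y_pred))

# Baseline trained on the contaminated data, for comparison.
clf = LogisticRegression(random_state=rng)
y_pred = clf.fit(X_train, y_train).predict(X_test)
print(classification_report(y_test, y_pred))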
