
adding parameters for device context and patching of Scikit-Learn #23

Closed · wants to merge 5 commits
39 changes: 39 additions & 0 deletions sklearn/bench.py
@@ -164,6 +164,12 @@ def parse_args(parser, size=None, loop_types=(),
                        help='Seed to pass as random_state')
    parser.add_argument('--dataset-name', type=str, default=None,
                        help='Dataset name')
    parser.add_argument('--device', type=str, default='None',
Contributor:

I guess these parameters should be enabled in the higher-level scripts that will execute them, but I don't see this in make?

Another question: are we going to pass a single value to bench.py, or should this be a tuple that bench.py iterates over?

Contributor (Author):

So far I have only used the runner.py script. It is possible that additional changes are needed in make or elsewhere. Thanks.

Devices are provided as a list in the config, e.g. "device": ["None", "host", "cpu", "gpu"]. See the example above and the sketch after this comment; the benchmarks are executed with every device in the list.
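A minimal sketch of such a config fragment, written here as a Python literal and assuming the runner launches each benchmark once per listed device (the surrounding structure of the real config file is not shown in this PR):

# Hypothetical runner config fragment; the real config is assumed to be JSON
# with the same "device" key described in the comment above.
config = {
    "device": ["None", "host", "cpu", "gpu"],  # one benchmark run per entry
}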

                        choices=('None', 'host', 'cpu', 'gpu'),
                        help='Execution context device, "None" to run without context.')
    parser.add_argument('--patch_sklearn', type=str, default='None',
                        choices=('None', 'True', 'False'),
                        help='True for patch, False for unpatch, "None" to leave as is.')

    for data in ['X', 'y']:
        for stage in ['train', 'test']:
@@ -618,3 +624,36 @@ def import_fptype_getter():
    except:
        from daal4py.sklearn.utils import getFPType
    return getFPType


def patch_sklearn():
    parser = argparse.ArgumentParser()
    parser.add_argument('--patch_sklearn', type=str, default='None',
                        choices=('None', 'True', 'False'),
                        help='True for patch, False for unpatch, "None" to leave as is.')
    args, _ = parser.parse_known_args()

    if args.patch_sklearn is not None and args.patch_sklearn != 'None':
        from daal4py.sklearn import patch_sklearn, unpatch_sklearn
        if args.patch_sklearn == "True":
Contributor:

Why are "True" and "False" strings rather than booleans?

Contributor (Author), @Alexander-Makaryev, May 19, 2020:

There is one more option, "None", which leaves the patching state as is. I think it can be useful for backward compatibility. Possibly we should change it to a boolean.

Contributor:

What would break in that case?

Contributor:

It's not common to have to write --patch-sklearn True. It might be better to have two options instead: --patch-sklearn, which sets patch_sklearn to True, and --no-patch-sklearn, which sets it to False, with None as the default. This can be done by adding two arguments with the same dest='patch_sklearn' but different actions, store_true and store_false, and the same default None. You may also want to put both arguments into a mutually exclusive group; see the sketch after this comment.

For the --device option, I would keep it the way you have it right now, but just let the default be None (the value, not the string), and not allow specifying None on the command line (unless you call it something like auto). --device None is a bit confusing, but --device cpu or --device auto reads much better.

If possible, it would be nice not to repeat the exact same argument-parsing construction in two places for both the device and the patching. What if you added the same tri-state patch argument as a kwarg to patch_sklearn, and the device as a kwarg to run_with_context?
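A minimal sketch of this suggestion, using only the standard-library argparse module (illustrative, not part of the PR):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument('--patch-sklearn', dest='patch_sklearn',
                   action='store_true', default=None,
                   help='Patch scikit-learn with daal4py before running')
group.add_argument('--no-patch-sklearn', dest='patch_sklearn',
                   action='store_false', default=None,
                   help='Unpatch scikit-learn before running')
args, _ = parser.parse_known_args()
# args.patch_sklearn is now True, False, or None ("leave as is")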

            patch_sklearn()
        elif args.patch_sklearn == "False":
            unpatch_sklearn()
        else:
            raise ValueError('Parameter "patch_sklearn" must be '
                             '"None", "True" or "False", got {}.'.format(args.patch_sklearn))


def run_with_context(function):
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='None',
                        choices=('None', 'host', 'cpu', 'gpu'),
                        help='Execution context device, "None" to run without context.')
    args, _ = parser.parse_known_args()

    if args.device is not None and args.device != 'None':
        from daal4py.oneapi import sycl_context
        with sycl_context(args.device):
            function()
    else:
        function()
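With these helpers, a wrapped benchmark can be driven from the command line. A hypothetical invocation, using the flag names defined by the parsers above (script path assumed):

python sklearn/dbscan.py --device gpu --patch_sklearn True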
86 changes: 47 additions & 39 deletions sklearn/dbscan.py
@@ -2,42 +2,50 @@
#
# SPDX-License-Identifier: MIT

import argparse
from bench import measure_function_time, parse_args, load_data, print_output
from sklearn.cluster import DBSCAN
from sklearn.metrics.cluster import davies_bouldin_score

parser = argparse.ArgumentParser(description='scikit-learn DBSCAN benchmark')
parser.add_argument('-e', '--eps', '--epsilon', type=float, default=10.,
                    help='Radius of neighborhood of a point')
parser.add_argument('-m', '--min-samples', default=5, type=int,
                    help='The minimum number of samples required in a '
                         'neighborhood to consider a point a core point')
params = parse_args(parser, n_jobs_supported=True)

# Load generated data
X, _, _, _ = load_data(params, add_dtype=True)

# Create our clustering object
dbscan = DBSCAN(eps=params.eps, n_jobs=params.n_jobs,
                min_samples=params.min_samples, metric='euclidean',
                algorithm='auto')

# N.B. algorithm='auto' will select DAAL's brute force method when running
# daal4py-patched scikit-learn, and probably 'kdtree' when running unpatched
# scikit-learn.

columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
           'n_clusters', 'time')

# Time fit
time, _ = measure_function_time(dbscan.fit, X, params=params)
labels = dbscan.labels_

params.n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
acc = davies_bouldin_score(X, labels)

print_output(library='sklearn', algorithm='dbscan', stages=['training'],
             columns=columns, params=params, functions=['DBSCAN'],
             times=[time], accuracies=[acc],
             accuracy_type='davies_bouldin_score', data=[X],
             alg_instance=dbscan)
from bench import (measure_function_time, parse_args, load_data, print_output,
                   run_with_context, patch_sklearn)


def main():
    import argparse
    from sklearn.cluster import DBSCAN
    from sklearn.metrics.cluster import davies_bouldin_score

    parser = argparse.ArgumentParser(description='scikit-learn DBSCAN benchmark')
    parser.add_argument('-e', '--eps', '--epsilon', type=float, default=10.,
                        help='Radius of neighborhood of a point')
    parser.add_argument('-m', '--min-samples', default=5, type=int,
                        help='The minimum number of samples required in a '
                             'neighborhood to consider a point a core point')
    params = parse_args(parser, n_jobs_supported=True)

    # Load generated data
    X, _, _, _ = load_data(params, add_dtype=True)

    # Create our clustering object
    dbscan = DBSCAN(eps=params.eps, n_jobs=params.n_jobs,
                    min_samples=params.min_samples, metric='euclidean',
                    algorithm='auto')

    # N.B. algorithm='auto' will select DAAL's brute force method when running
    # daal4py-patched scikit-learn, and probably 'kdtree' when running unpatched
    # scikit-learn.

    columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
               'n_clusters', 'time')

    # Time fit
    time, _ = measure_function_time(dbscan.fit, X, params=params)
    labels = dbscan.labels_

    params.n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    acc = davies_bouldin_score(X, labels)

    print_output(library='sklearn', algorithm='dbscan', stages=['training'],
                 columns=columns, params=params, functions=['DBSCAN'],
                 times=[time], accuracies=[acc],
                 accuracy_type='davies_bouldin_score', data=[X],
                 alg_instance=dbscan)


if __name__ == "__main__":
    patch_sklearn()
    run_with_context(main)
123 changes: 65 additions & 58 deletions sklearn/kmeans.py
@@ -1,71 +1,78 @@
# Copyright (C) 2017-2020 Intel Corporation
# Copyright (C) 2018-2020 Intel Corporation
#
# SPDX-License-Identifier: MIT

import argparse
from bench import (
    parse_args, measure_function_time, load_data, print_output
)
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import davies_bouldin_score

parser = argparse.ArgumentParser(description='scikit-learn K-means benchmark')
parser.add_argument('-i', '--filei', '--fileI', '--init',
                    type=str, help='Initial clusters')
parser.add_argument('-t', '--tol', type=float, default=0.,
                    help='Absolute threshold')
parser.add_argument('--maxiter', type=int, default=100,
                    help='Maximum number of iterations')
parser.add_argument('--n-clusters', type=int, help='Number of clusters')
params = parse_args(parser)

# Load and convert generated data
X_train, X_test, _, _ = load_data(params)

if params.filei == 'k-means++':
    X_init = 'k-means++'
# Load initial centroids from specified path
elif params.filei is not None:
    X_init = np.load(params.filei).astype(params.dtype)
    params.n_clusters = X_init.shape[0]
# or choose random centroids from training data
else:
    np.random.seed(params.seed)
    centroids_idx = np.random.randint(0, X_train.shape[0],
                                      size=params.n_clusters)
    if hasattr(X_train, "iloc"):
        X_init = X_train.iloc[centroids_idx].values
    else:
        X_init = X_train[centroids_idx]
from bench import (measure_function_time, parse_args, load_data, print_output,
                   run_with_context, patch_sklearn)


def main():
    import argparse
    import numpy as np
    from sklearn.cluster import KMeans
    from sklearn.metrics.cluster import davies_bouldin_score

def fit_kmeans(X):
    global X_init, params
    alg = KMeans(n_clusters=params.n_clusters, tol=params.tol,
                 max_iter=params.maxiter, init=X_init, n_init=1)
    alg.fit(X)
    return alg

    parser = argparse.ArgumentParser(description='scikit-learn K-means benchmark')
    parser.add_argument('-i', '--filei', '--fileI', '--init',
                        type=str, help='Initial clusters')
    parser.add_argument('-t', '--tol', type=float, default=0.,
                        help='Absolute threshold')
    parser.add_argument('--maxiter', type=int, default=100,
                        help='Maximum number of iterations')
    parser.add_argument('--n-clusters', type=int, help='Number of clusters')
    params = parse_args(parser)

    # Load and convert generated data
    X_train, X_test, _, _ = load_data(params)

    if params.filei == 'k-means++':
        X_init = 'k-means++'
    # Load initial centroids from specified path
    elif params.filei is not None:
        X_init = np.load(params.filei).astype(params.dtype)
        params.n_clusters = X_init.shape[0]
    # or choose random centroids from training data
    else:
        np.random.seed(params.seed)
        centroids_idx = np.random.randint(0, X_train.shape[0],
                                          size=params.n_clusters)
        if hasattr(X_train, "iloc"):
            X_init = X_train.iloc[centroids_idx].values
        else:
            X_init = X_train[centroids_idx]


    def fit_kmeans(X):
        # X_init and params are captured from the enclosing main() scope
        alg = KMeans(n_clusters=params.n_clusters, tol=params.tol,
                     max_iter=params.maxiter, init=X_init, n_init=1)
        alg.fit(X)
        return alg


columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
           'n_clusters', 'time')

    columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
               'n_clusters', 'time')

# Time fit
fit_time, kmeans = measure_function_time(fit_kmeans, X_train, params=params)

    # Time fit
    fit_time, kmeans = measure_function_time(fit_kmeans, X_train, params=params)

train_predict = kmeans.predict(X_train)
acc_train = davies_bouldin_score(X_train, train_predict)

    train_predict = kmeans.predict(X_train)
    acc_train = davies_bouldin_score(X_train, train_predict)

# Time predict
predict_time, test_predict = measure_function_time(
    kmeans.predict, X_test, params=params)

    # Time predict
    predict_time, test_predict = measure_function_time(
        kmeans.predict, X_test, params=params)

acc_test = davies_bouldin_score(X_test, test_predict)

    acc_test = davies_bouldin_score(X_test, test_predict)

print_output(library='sklearn', algorithm='kmeans',
             stages=['training', 'prediction'], columns=columns,
             params=params, functions=['KMeans.fit', 'KMeans.predict'],
             times=[fit_time, predict_time], accuracy_type='davies_bouldin_score',
             accuracies=[acc_train, acc_test], data=[X_train, X_test],
             alg_instance=kmeans)

    print_output(library='sklearn', algorithm='kmeans',
                 stages=['training', 'prediction'], columns=columns,
                 params=params, functions=['KMeans.fit', 'KMeans.predict'],
                 times=[fit_time, predict_time], accuracy_type='davies_bouldin_score',
                 accuracies=[acc_train, acc_test], data=[X_train, X_test],
                 alg_instance=kmeans)


if __name__ == "__main__":
    patch_sklearn()
    run_with_context(main)
87 changes: 47 additions & 40 deletions sklearn/linear.py
@@ -2,43 +2,50 @@
#
# SPDX-License-Identifier: MIT

import argparse
from bench import (
    parse_args, measure_function_time, load_data, print_output, rmse_score
)
from sklearn.linear_model import LinearRegression

parser = argparse.ArgumentParser(description='scikit-learn linear regression '
                                             'benchmark')
parser.add_argument('--no-fit-intercept', dest='fit_intercept', default=True,
                    action='store_false',
                    help="Don't fit intercept (assume data already centered)")
params = parse_args(parser, size=(1000000, 50))

# Load data
X_train, X_test, y_train, y_test = load_data(
    params, generated_data=['X_train', 'y_train'])

# Create our regression object
regr = LinearRegression(fit_intercept=params.fit_intercept,
                        n_jobs=params.n_jobs, copy_X=False)

columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
           'time')

# Time fit
fit_time, _ = measure_function_time(regr.fit, X_train, y_train, params=params)

# Time predict
predict_time, yp = measure_function_time(regr.predict, X_test, params=params)

test_rmse = rmse_score(yp, y_test)
yp = regr.predict(X_train)
train_rmse = rmse_score(yp, y_train)

print_output(library='sklearn', algorithm='linear_regression',
             stages=['training', 'prediction'], columns=columns,
             params=params, functions=['Linear.fit', 'Linear.predict'],
             times=[fit_time, predict_time], accuracy_type='rmse',
             accuracies=[train_rmse, test_rmse], data=[X_train, X_test],
             alg_instance=regr)
from bench import (measure_function_time, parse_args, load_data, print_output,
                   rmse_score, run_with_context, patch_sklearn)


def main():
    import argparse
    from sklearn.linear_model import LinearRegression

    parser = argparse.ArgumentParser(description='scikit-learn linear regression '
                                                 'benchmark')
    parser.add_argument('--no-fit-intercept', dest='fit_intercept', default=True,
                        action='store_false',
                        help="Don't fit intercept (assume data already centered)")
    params = parse_args(parser, size=(1000000, 50))

    # Load data
    X_train, X_test, y_train, y_test = load_data(
        params, generated_data=['X_train', 'y_train'])

    # Create our regression object
    regr = LinearRegression(fit_intercept=params.fit_intercept,
                            n_jobs=params.n_jobs, copy_X=False)

    columns = ('batch', 'arch', 'prefix', 'function', 'threads', 'dtype', 'size',
               'time')

    # Time fit
    fit_time, _ = measure_function_time(regr.fit, X_train, y_train, params=params)

    # Time predict
    predict_time, yp = measure_function_time(regr.predict, X_test, params=params)

    test_rmse = rmse_score(yp, y_test)
    yp = regr.predict(X_train)
    train_rmse = rmse_score(yp, y_train)

    print_output(library='sklearn', algorithm='linear_regression',
                 stages=['training', 'prediction'], columns=columns,
                 params=params, functions=['Linear.fit', 'Linear.predict'],
                 times=[fit_time, predict_time], accuracy_type='rmse',
                 accuracies=[train_rmse, test_rmse], data=[X_train, X_test],
                 alg_instance=regr)


if __name__ == "__main__":
    patch_sklearn()
    run_with_context(main)