Add device context parameter #57

Merged: 7 commits, Mar 25, 2021
Changes from 4 commits
16 changes: 15 additions & 1 deletion bench.py
@@ -175,7 +175,11 @@ def parse_args(parser, size=None, loop_types=(),
help='Dataset name')
parser.add_argument('--no-intel-optimized', default=False, action='store_true',
help='Use no intel optimized version. '
'Now available for scikit-learn benchmarks'),
'Now available for scikit-learn benchmarks')
parser.add_argument('--device', default=None, type=str,
Contributor (suggested change):
-parser.add_argument('--device', default=None, type=str,
+parser.add_argument('--device', default="host", type=str,
Contributor: My understanding is that None is used to run without a context; other values specify the device type for a context.

Contributor: Ohh, then I think the "host" default is not needed at all.

choices=("host", "cpu", "gpu"),
Contributor: Shall None also be included in choices?

help='Execution context device')
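
To the None question above: argparse validates a value against choices only when it is supplied on the command line; a default is stored as-is, so default=None needs no entry in choices. A minimal standalone sketch:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--device', default=None, type=str,
                    choices=("host", "cpu", "gpu"))

print(parser.parse_args([]))                   # Namespace(device=None)
print(parser.parse_args(['--device', 'gpu']))  # Namespace(device='gpu')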

for data in ['X', 'y']:
for stage in ['train', 'test']:
parser.add_argument(f'--file-{data}-{stage}',
@@ -197,6 +201,8 @@ def parse_args(parser, size=None, loop_types=(),
except ImportError:
print('Failed to import daal4py.sklearn.patch_sklearn.'
'Use stock version scikit-learn', file=sys.stderr)
else:
params.device = None
Contributor: I think we should check whether the device parameter was passed by the user and print a warning that it is useless in that case, for clarity.
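
A sketch of that suggestion (hypothetical, not part of this change), assuming the else branch above pairs with the --no-intel-optimized check from the surrounding context rather than with the try/except:

else:
    if params.device is not None:
        print('Warning: --device is ignored with --no-intel-optimized; '
              'running without a SYCL context', file=sys.stderr)
    params.device = None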


# disable finiteness check (default)
if not params.check_finiteness:
@@ -492,3 +498,11 @@ def print_output(library, algorithm, stages, params, functions,
del result['algorithm_parameters']['handle']
output.append(result)
print(json.dumps(output, indent=4))

def run_with_context(params, function):
if params.device is not None:
from daal4py.oneapi import sycl_context
with sycl_context(params.device):
function()
else:
function()
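
Every benchmark script changed below calls this helper the same way; the pattern, as it appears in the dbscan.py changes in this diff:

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='scikit-learn DBSCAN benchmark')
    # ... benchmark-specific arguments ...
    params = bench.parse_args(parser)
    bench.run_with_context(params, main)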
77 changes: 77 additions & 0 deletions configs/skl_with_context_config.json
@@ -0,0 +1,77 @@
{
"common": {
"lib": ["sklearn"],
Contributor: How is the stock or Intel version of scikit-learn specified for this config?

Author: Do we need to launch stock scikit-learn in this config? Maybe just add the --no-intel-optimized flag before the config run? I think it's better to separate patched and unpatched scikit-learn launches.

Contributor: OK, probably it's better to skip the device != None branches for stock scikit-learn.

Author: Should support be added to skip these cases?

"data-format": ["pandas"],
"data-order": ["F"],
"dtype": ["float64"],
"device": ["host", "cpu", "gpu"]
Contributor: What happens if I run this config on a machine without a GPU driver?

Contributor: Exactly the same as what happens if you try to run on a CPU without DPC++ support: an exception. What is your suggestion here?

Contributor: We should note somewhere that using this config file requires DPC++ support and a GPU device on board.

Contributor: OK, probably that's fine.
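
One way to degrade gracefully on such machines, sketched as a hypothetical pre-flight probe (the exception type raised for an unavailable device is not pinned down in this thread, hence the broad except):

from daal4py.oneapi import sycl_context

def available_devices(candidates=("host", "cpu", "gpu")):
    usable = []
    for device in candidates:
        try:
            # Entering a context fails if the driver or DPC++
            # runtime for this device is missing.
            with sycl_context(device):
                pass
        except Exception:
            continue
        usable.append(device)
    return usable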

},
"cases": [
{
"algorithm": "kmeans",
"dataset": [
{
"source": "synthetic",
"type": "blobs",
"n_clusters": 10,
"n_features": 50,
"training": {
"n_samples": 1000000
}
}
],
"n-clusters": [10]
},
{
"algorithm": "dbscan",
"dataset": [
{
"source": "synthetic",
"type": "blobs",
"n_clusters": 10,
"n_features": 50,
"training": {
"n_samples": 10000
}
}
]
},
{
"algorithm": "linear",
"dataset": [
{
"source": "synthetic",
"type": "regression",
"n_features": 50,
"training": {
"n_samples": 1000000
}
}
]
},
{
"algorithm": "log_reg",
"solver":["lbfgs", "newton-cg"],
"dataset": [
{
"source": "synthetic",
"type": "classification",
"n_classes": 2,
"n_features": 100,
"training": {
"n_samples": 100000
}
},
{
"source": "synthetic",
"type": "classification",
"n_classes": 5,
"n_features": 100,
"training": {
"n_samples": 100000
}
}
]
}
]
}
3 changes: 3 additions & 0 deletions runner.py
@@ -70,6 +70,9 @@ def generate_cases(params):
parser.add_argument('--report', default=False, action='store_true',
help='Create an Excel report based on benchmarks results. '
'Need "openpyxl" library')
parser.add_argument('--device', default=None, type=str,
Contributor: Why is the parameter duplicated in bench.py and runner.py?

choices=("host", "cpu", "gpu"),
help='Execution context device')
args = parser.parse_args()
env = os.environ.copy()
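
On the duplication question above: runner.py launches each benchmark script as a subprocess, so the option presumably has to be both parsed here and forwarded on the child's command line; a hypothetical forwarding snippet (the command variable is assumed, not from this diff):

if args.device is not None:
    command += ' --device ' + args.device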

71 changes: 37 additions & 34 deletions sklearn_bench/dbscan.py
@@ -19,38 +19,41 @@
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import bench
from sklearn.metrics.cluster import davies_bouldin_score

parser = argparse.ArgumentParser(description='scikit-learn DBSCAN benchmark')
parser.add_argument('-e', '--eps', '--epsilon', type=float, default=10.,
help='Radius of neighborhood of a point')
parser.add_argument('-m', '--min-samples', default=5, type=int,
help='The minimum number of samples required in a '
'neighborhood to consider a point a core point')
params = bench.parse_args(parser)

from sklearn.cluster import DBSCAN

# Load generated data
X, _, _, _ = bench.load_data(params, add_dtype=True)

# Create our clustering object
dbscan = DBSCAN(eps=params.eps, n_jobs=params.n_jobs,
min_samples=params.min_samples, metric='euclidean',
algorithm='auto')

# N.B. algorithm='auto' will select DAAL's brute force method when running
# daal4py-patched scikit-learn, and probably 'kdtree' when running unpatched
# scikit-learn.

# Time fit
time, _ = bench.measure_function_time(dbscan.fit, X, params=params)
labels = dbscan.labels_

params.n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
acc = davies_bouldin_score(X, labels)

bench.print_output(library='sklearn', algorithm='dbscan', stages=['training'],
params=params, functions=['DBSCAN'], times=[time], accuracies=[acc],
accuracy_type='davies_bouldin_score', data=[X],
alg_instance=dbscan)
def main():
from sklearn.cluster import DBSCAN
from sklearn.metrics.cluster import davies_bouldin_score

# Load generated data
X, _, _, _ = bench.load_data(params, add_dtype=True)

# Create our clustering object
dbscan = DBSCAN(eps=params.eps, n_jobs=params.n_jobs,
min_samples=params.min_samples, metric='euclidean',
algorithm='auto')

# N.B. algorithm='auto' will select DAAL's brute force method when running
Contributor (suggested change):
-# N.B. algorithm='auto' will select DAAL's brute force method when running
+# N.B. algorithm='auto' will select oneAPI Data Analytics Library (oneDAL) brute force method when running

Contributor: @PetrovKP what about other files that @vlad-nazarov did not touch?

Contributor: I will correct any such places that remain.

# daal4py-patched scikit-learn, and probably 'kdtree' when running unpatched
# scikit-learn.

# Time fit
time, _ = bench.measure_function_time(dbscan.fit, X, params=params)
labels = dbscan.labels_

params.n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
acc = davies_bouldin_score(X, labels)

bench.print_output(library='sklearn', algorithm='dbscan', stages=['training'],
params=params, functions=['DBSCAN'], times=[time], accuracies=[acc],
accuracy_type='davies_bouldin_score', data=[X],
alg_instance=dbscan)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='scikit-learn DBSCAN benchmark')
parser.add_argument('-e', '--eps', '--epsilon', type=float, default=10.,
help='Radius of neighborhood of a point')
parser.add_argument('-m', '--min-samples', default=5, type=int,
help='The minimum number of samples required in a '
'neighborhood to consider a point a core point')
params = bench.parse_args(parser)
bench.run_with_context(params, main)
98 changes: 50 additions & 48 deletions sklearn_bench/kmeans.py
@@ -20,63 +20,65 @@
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import bench
import numpy as np
from sklearn.metrics.cluster import davies_bouldin_score

parser = argparse.ArgumentParser(description='scikit-learn K-means benchmark')
parser.add_argument('-i', '--filei', '--fileI', '--init',
type=str, help='Initial clusters')
parser.add_argument('-t', '--tol', type=float, default=0.,
help='Absolute threshold')
parser.add_argument('--maxiter', type=int, default=100,
help='Maximum number of iterations')
parser.add_argument('--n-clusters', type=int, help='Number of clusters')
params = bench.parse_args(parser)
def main():
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import davies_bouldin_score

from sklearn.cluster import KMeans
# Load and convert generated data
X_train, X_test, _, _ = bench.load_data(params)

# Load and convert generated data
X_train, X_test, _, _ = bench.load_data(params)

if params.filei == 'k-means++':
X_init = 'k-means++'
# Load initial centroids from specified path
elif params.filei is not None:
X_init = np.load(params.filei).astype(params.dtype)
params.n_clusters = X_init.shape[0]
# or choose random centroids from training data
else:
np.random.seed(params.seed)
centroids_idx = np.random.randint(0, X_train.shape[0],
size=params.n_clusters)
if hasattr(X_train, "iloc"):
X_init = X_train.iloc[centroids_idx].values
if params.filei == 'k-means++':
X_init = 'k-means++'
# Load initial centroids from specified path
elif params.filei is not None:
X_init = np.load(params.filei).astype(params.dtype)
params.n_clusters = X_init.shape[0]
# or choose random centroids from training data
else:
X_init = X_train[centroids_idx]
np.random.seed(params.seed)
centroids_idx = np.random.randint(0, X_train.shape[0],
size=params.n_clusters)
if hasattr(X_train, "iloc"):
X_init = X_train.iloc[centroids_idx].values
else:
X_init = X_train[centroids_idx]


def fit_kmeans(X):
global X_init, params
alg = KMeans(n_clusters=params.n_clusters, tol=params.tol,
max_iter=params.maxiter, init=X_init, n_init=1)
alg.fit(X)
return alg
def fit_kmeans(X, X_init):
alg = KMeans(n_clusters=params.n_clusters, tol=params.tol,
max_iter=params.maxiter, init=X_init, n_init=1)
alg.fit(X)
return alg

# Time fit
fit_time, kmeans = bench.measure_function_time(fit_kmeans, X_train, X_init, params=params)

# Time fit
fit_time, kmeans = bench.measure_function_time(fit_kmeans, X_train, params=params)
train_predict = kmeans.predict(X_train)
acc_train = davies_bouldin_score(X_train, train_predict)

train_predict = kmeans.predict(X_train)
acc_train = davies_bouldin_score(X_train, train_predict)
# Time predict
predict_time, test_predict = bench.measure_function_time(
kmeans.predict, X_test, params=params)

# Time predict
predict_time, test_predict = bench.measure_function_time(
kmeans.predict, X_test, params=params)
acc_test = davies_bouldin_score(X_test, test_predict)

acc_test = davies_bouldin_score(X_test, test_predict)
bench.print_output(library='sklearn', algorithm='kmeans',
stages=['training', 'prediction'],
params=params, functions=['KMeans.fit', 'KMeans.predict'],
times=[fit_time, predict_time], accuracy_type='davies_bouldin_score',
accuracies=[acc_train, acc_test], data=[X_train, X_test],
alg_instance=kmeans)

bench.print_output(library='sklearn', algorithm='kmeans',
stages=['training', 'prediction'],
params=params, functions=['KMeans.fit', 'KMeans.predict'],
times=[fit_time, predict_time], accuracy_type='davies_bouldin_score',
accuracies=[acc_train, acc_test], data=[X_train, X_test],
alg_instance=kmeans)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='scikit-learn K-means benchmark')
parser.add_argument('-i', '--filei', '--fileI', '--init',
type=str, help='Initial clusters')
parser.add_argument('-t', '--tol', type=float, default=0.,
help='Absolute threshold')
parser.add_argument('--maxiter', type=int, default=100,
help='Maximum number of iterations')
parser.add_argument('--n-clusters', type=int, help='Number of clusters')
params = bench.parse_args(parser)
bench.run_with_context(params, main)
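
Passing X_init through measure_function_time (in the fit call above) works because the helper forwards positional arguments to the measured callable; schematically, as a hypothetical reduction of the real helper in bench.py:

import timeit

def measure_function_time(func, *args, params, **kwargs):
    # Forward *args (e.g. X_train, X_init) to the measured callable
    # and return the elapsed time together with its result.
    t0 = timeit.default_timer()
    result = func(*args, **kwargs)
    return timeit.default_timer() - t0, result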

55 changes: 29 additions & 26 deletions sklearn_bench/linear.py
@@ -21,36 +21,39 @@
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import bench

parser = argparse.ArgumentParser(description='scikit-learn linear regression '
'benchmark')
parser.add_argument('--no-fit-intercept', dest='fit_intercept', default=True,
action='store_false',
help="Don't fit intercept (assume data already centered)")
params = bench.parse_args(parser)
def main():
from sklearn.linear_model import LinearRegression

from sklearn.linear_model import LinearRegression
# Load data
X_train, X_test, y_train, y_test = bench.load_data(
params, generated_data=['X_train', 'y_train'])

# Load data
X_train, X_test, y_train, y_test = bench.load_data(
params, generated_data=['X_train', 'y_train'])
# Create our regression object
regr = LinearRegression(fit_intercept=params.fit_intercept,
n_jobs=params.n_jobs, copy_X=False)

# Create our regression object
regr = LinearRegression(fit_intercept=params.fit_intercept,
n_jobs=params.n_jobs, copy_X=False)
# Time fit
fit_time, _ = bench.measure_function_time(regr.fit, X_train, y_train, params=params)

# Time fit
fit_time, _ = bench.measure_function_time(regr.fit, X_train, y_train, params=params)
# Time predict
predict_time, yp = bench.measure_function_time(regr.predict, X_test, params=params)

# Time predict
predict_time, yp = bench.measure_function_time(regr.predict, X_test, params=params)
test_rmse = bench.rmse_score(yp, y_test)
yp = regr.predict(X_train)
train_rmse = bench.rmse_score(yp, y_train)

test_rmse = bench.rmse_score(yp, y_test)
yp = regr.predict(X_train)
train_rmse = bench.rmse_score(yp, y_train)
bench.print_output(library='sklearn', algorithm='linear_regression',
stages=['training', 'prediction'],
params=params, functions=['Linear.fit', 'Linear.predict'],
times=[fit_time, predict_time], accuracy_type='rmse',
accuracies=[train_rmse, test_rmse], data=[X_train, X_test],
alg_instance=regr)

bench.print_output(library='sklearn', algorithm='linear_regression',
stages=['training', 'prediction'],
params=params, functions=['Linear.fit', 'Linear.predict'],
times=[fit_time, predict_time], accuracy_type='rmse',
accuracies=[train_rmse, test_rmse], data=[X_train, X_test],
alg_instance=regr)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='scikit-learn linear regression '
'benchmark')
parser.add_argument('--no-fit-intercept', dest='fit_intercept', default=True,
action='store_false',
help="Don't fit intercept (assume data already centered)")
params = bench.parse_args(parser)
bench.run_with_context(params, main)