Skip to content

Extend output result & minor fixes #81

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Aug 4, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 44 additions & 34 deletions bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def get_dtype(data):
'''
if hasattr(data, 'dtype'):
return data.dtype
elif hasattr(data, 'dtypes'):
if hasattr(data, 'dtypes'):
return str(data.dtypes[0])
elif hasattr(data, 'values'):
if hasattr(data, 'values'):
return data.values.dtype
else:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just raise an exception here, without else

raise ValueError(f'Impossible to get data type of {type(data)}')
Expand Down Expand Up @@ -66,10 +66,7 @@ def _parse_size(string, dim=2):


def float_or_int(string):
if '.' in string:
return float(string)
else:
return int(string)
return float(string) if '.' in string else int(string)


def get_optimal_cache_size(n_rows, dtype=np.double, max_cache=64):
Expand All @@ -90,10 +87,8 @@ def get_optimal_cache_size(n_rows, dtype=np.double, max_cache=64):
optimal_cache_size_bytes = byte_size * (n_rows ** 2)
one_gb = 2 ** 30
max_cache_bytes = max_cache * one_gb
if optimal_cache_size_bytes > max_cache_bytes:
return max_cache_bytes
else:
return optimal_cache_size_bytes
return max_cache_bytes if optimal_cache_size_bytes > max_cache_bytes \
else optimal_cache_size_bytes

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An extra comma at the end of line 175 'Now avalible for scikit-learn benchmarks'),



def parse_args(parser, size=None, loop_types=(),
Expand Down Expand Up @@ -324,34 +319,47 @@ def convert_to_numpy(data):
return data


def columnwise_score(y, yp, score_func):
def accuracy_score(y, yp):
from sklearn.metrics import accuracy_score as sklearn_accuracy
y = convert_to_numpy(y)
yp = convert_to_numpy(yp)
if y.ndim + yp.ndim > 2:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is that code really useless? does sklearn_accuracy work same?

if 1 in (y.shape + yp.shape)[1:]:
if y.ndim > 1:
y = y[:, 0]
if yp.ndim > 1:
yp = yp[:, 0]
else:
return [score_func(y[i], yp[i]) for i in range(y.shape[1])]
return score_func(y, yp)


def accuracy_score(y, yp):
return columnwise_score(y, yp, lambda y1, y2: np.mean(y1 == y2))
return sklearn_accuracy(y, yp)


def log_loss(y, yp):
from sklearn.metrics import log_loss as sklearn_log_loss
y = convert_to_numpy(y)
yp = convert_to_numpy(yp)
return sklearn_log_loss(y, yp)
try:
res = sklearn_log_loss(y, yp)
except Exception:
res = None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handling exceptions like this is a bad practice: every mistake will be hidden by this code
Is log_loss always properly called?
Is it always given probabilities and not labels?
Is the argument order right?
Same questions should be answered for other metrics

return res


def roc_auc_score(y, yp, multi_class='ovr'):
from sklearn.metrics import roc_auc_score as sklearn_roc_auc
y = convert_to_numpy(y)
yp = convert_to_numpy(yp)
try:
res = sklearn_roc_auc(y, yp, multi_class=multi_class)
except Exception:
res = None
return res


def rmse_score(y, yp):
return columnwise_score(
y, yp, lambda y1, y2: float(np.sqrt(np.mean((y1 - y2)**2))))
from sklearn.metrics import mean_squared_error as sklearn_mse
y = convert_to_numpy(y)
yp = convert_to_numpy(yp)
return sklearn_mse(y, yp)


def r2_score(y, yp):
from sklearn.metrics import r2_score as sklearn_r2_score
y = convert_to_numpy(y)
yp = convert_to_numpy(yp)
return sklearn_r2_score(y, yp)


def convert_data(data, dtype, data_order, data_format):
Expand All @@ -367,14 +375,11 @@ def convert_data(data, dtype, data_order, data_format):
# Secondly, change format of data
if data_format == 'numpy':
return data
elif data_format == 'pandas':
if data_format == 'pandas':
import pandas as pd

if data.ndim == 1:
return pd.Series(data)
else:
return pd.DataFrame(data)
elif data_format == 'cudf':
return pd.Series(data) if data.ndim == 1 else pd.DataFrame(data)
if data_format == 'cudf':
import cudf
import pandas as pd

Expand Down Expand Up @@ -497,7 +502,12 @@ def print_output(library, algorithm, stages, params, functions,
data[i], alg_instance, alg_params)
result.update({'time[s]': times[i]})
if accuracy_type is not None:
result.update({f'{accuracy_type}': accuracies[i]})
if isinstance(accuracy_type, str):
result.update({f'{accuracy_type}': accuracies[i]})
elif isinstance(accuracy_type, list):
for ind, val in enumerate(accuracy_type):
if accuracies[ind][i] is not None:
result.update({f'{val}': accuracies[ind][i]})
if hasattr(params, 'n_classes'):
result['input_data'].update({'classes': params.n_classes})
if hasattr(params, 'n_clusters'):
Expand Down
2 changes: 1 addition & 1 deletion configs/blogs/skl_2021_3.json
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@
}
],
"nu": [0.25],
"kernel": ["sigmoid"]
"kernel": ["poly"]
},
{
"algorithm": "svr",
Expand Down
41 changes: 26 additions & 15 deletions datasets/loader_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ def airline(dataset_dir: Path) -> bool:
Airline dataset
http://kt.ijs.si/elena_ikonomovska/data.html

TaskType:binclass
NumberOfFeatures:13
NumberOfInstances:115M
Classification task. n_classes = 2.
airline X train dataset (92055213, 13)
airline y train dataset (92055213, 1)
airline X test dataset (23013804, 13)
airline y test dataset (23013804, 1)
"""
dataset_name = 'airline'
os.makedirs(dataset_dir, exist_ok=True)
Expand Down Expand Up @@ -126,9 +128,12 @@ def airline(dataset_dir: Path) -> bool:
def airline_ohe(dataset_dir: Path) -> bool:
"""
Dataset from szilard benchmarks: https://github.com/szilard/GBM-perf
TaskType:binclass
NumberOfFeatures:700
NumberOfInstances:10100000

Classification task. n_classes = 2.
airline-ohe X train dataset (1000000, 692)
airline-ohe y train dataset (1000000, 1)
airline-ohe X test dataset (100000, 692)
airline-ohe y test dataset (100000, 1)
"""
dataset_name = 'airline-ohe'
os.makedirs(dataset_dir, exist_ok=True)
Expand Down Expand Up @@ -289,9 +294,11 @@ def epsilon(dataset_dir: Path) -> bool:
Epsilon dataset
https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html

TaskType:binclass
NumberOfFeatures:2000
NumberOfInstances:500K
Classification task. n_classes = 2.
epsilon X train dataset (400000, 2000)
epsilon y train dataset (400000, 1)
epsilon X test dataset (100000, 2000)
epsilon y test dataset (100000, 1)
"""
dataset_name = 'epsilon'
os.makedirs(dataset_dir, exist_ok=True)
Expand Down Expand Up @@ -444,9 +451,11 @@ def higgs(dataset_dir: Path) -> bool:
Higgs dataset from UCI machine learning repository
https://archive.ics.uci.edu/ml/datasets/HIGGS

TaskType:binclass
NumberOfFeatures:28
NumberOfInstances:11M
Classification task. n_classes = 2.
higgs X train dataset (8799999, 28)
higgs y train dataset (8799999, 1)
higgs X test dataset (2200000, 28)
higgs y test dataset (2200000, 1)
"""
dataset_name = 'higgs'
os.makedirs(dataset_dir, exist_ok=True)
Expand Down Expand Up @@ -479,9 +488,11 @@ def higgs_one_m(dataset_dir: Path) -> bool:

Only first 1.5M samples is taken

TaskType:binclass
NumberOfFeatures:28
NumberOfInstances:1.5M
Classification task. n_classes = 2.
higgs1m X train dataset (1000000, 28)
higgs1m y train dataset (1000000, 1)
higgs1m X test dataset (500000, 28)
higgs1m y test dataset (500000, 1)
"""
dataset_name = 'higgs1m'
os.makedirs(dataset_dir, exist_ok=True)
Expand Down
26 changes: 16 additions & 10 deletions datasets/loader_multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,11 @@ def covtype(dataset_dir: Path) -> bool:
https://archive.ics.uci.edu/ml/datasets/covertype

y contains 7 unique class labels from 1 to 7 inclusive.
TaskType:multiclass
NumberOfFeatures:54
NumberOfInstances:581012
Classification task. n_classes = 7.
covtype X train dataset (464809, 54)
covtype y train dataset (464809, 1)
covtype X test dataset (116203, 54)
covtype y test dataset (116203, 1)
"""
dataset_name = 'covtype'
os.makedirs(dataset_dir, exist_ok=True)
Expand All @@ -125,9 +127,11 @@ def letters(dataset_dir: Path) -> bool:
"""
http://archive.ics.uci.edu/ml/datasets/Letter+Recognition

TaskType:multiclass
NumberOfFeatures:16
NumberOfInstances:20.000
Classification task. n_classes = 26.
letters X train dataset (16000, 16)
letters y train dataset (16000, 1)
letters X test dataset (4000, 16)
letters y test dataset (4000, 1)
"""
dataset_name = 'letters'
os.makedirs(dataset_dir, exist_ok=True)
Expand Down Expand Up @@ -204,9 +208,11 @@ def msrank(dataset_dir: Path) -> bool:
"""
Dataset from szilard benchmarks: https://github.com/szilard/GBM-perf

TaskType:multiclass
NumberOfFeatures:137
NumberOfInstances:1.2M
Classification task. n_classes = 5.
msrank X train dataset (958671, 137)
msrank y train dataset (958671, 1)
msrank X test dataset (241521, 137)
msrank y test dataset (241521, 1)
"""
dataset_name = 'msrank'
os.makedirs(dataset_dir, exist_ok=True)
Expand Down Expand Up @@ -264,7 +270,7 @@ def sensit(dataset_dir: Path) -> bool:
Author: M. Duarte, Y. H. Hu
Source: [original](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets)

Multiclass classification task
Classification task. n_classes = 3.
sensit X train dataset (78822, 100)
sensit y train dataset (78822, 1)
sensit X test dataset (19706, 100)
Expand Down
14 changes: 8 additions & 6 deletions datasets/loader_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ def abalone(dataset_dir: Path) -> bool:
"""
https://archive.ics.uci.edu/ml/machine-learning-databases/abalone

TaskType:regression
NumberOfFeatures:8
NumberOfInstances:4177
abalone x train dataset (3341, 8)
abalone y train dataset (3341, 1)
abalone x test dataset (836, 8)
abalone y train dataset (836, 1)
"""
dataset_name = 'abalone'
os.makedirs(dataset_dir, exist_ok=True)
Expand Down Expand Up @@ -196,9 +197,10 @@ def year_prediction_msd(dataset_dir: Path) -> bool:
YearPredictionMSD dataset from UCI repository
https://archive.ics.uci.edu/ml/datasets/yearpredictionmsd

TaskType:regression
NumberOfFeatures:90
NumberOfInstances:515345
year_prediction_msd x train dataset (463715, 90)
year_prediction_msd y train dataset (463715, 1)
year_prediction_msd x test dataset (51630, 90)
year_prediction_msd y train dataset (51630, 1)
"""
dataset_name = 'year_prediction_msd'
os.makedirs(dataset_dir, exist_ok=True)
Expand Down
7 changes: 5 additions & 2 deletions sklearn_bench/dbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def main():
labels = dbscan.labels_

params.n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
acc = davies_bouldin_score(X, labels)
try:
acc = davies_bouldin_score(X, labels)
except Exception:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to have all scoring functions in the same place (bench.py)
And again, such handling is a bad practice

acc = -1

bench.print_output(library='sklearn', algorithm='dbscan', stages=['training'],
params=params, functions=['DBSCAN'], times=[time],
Expand All @@ -50,7 +53,7 @@ def main():

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='scikit-learn DBSCAN benchmark')
parser.add_argument('-e', '--eps', '--epsilon', type=float, default=10.,
parser.add_argument('-e', '--eps', '--epsilon', type=float, default=0.5,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did you change it? this will not affect the measurements of the current configs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is default value in sklearn

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will not affect the measurements of the current configs?

help='Radius of neighborhood of a point')
parser.add_argument('-m', '--min-samples', default=5, type=int,
help='The minimum number of samples required in a '
Expand Down
31 changes: 22 additions & 9 deletions sklearn_bench/df_clsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import bench
import numpy as np
from sklearn.metrics import accuracy_score


def main():
Expand All @@ -43,18 +42,32 @@ def main():

fit_time, _ = bench.measure_function_time(clf.fit, X_train, y_train, params=params)
y_pred = clf.predict(X_train)
train_acc = 100 * accuracy_score(y_pred, y_train)
train_acc = bench.accuracy_score(y_train, y_pred)
train_log_loss = bench.log_loss(y_train, y_pred)
train_roc_auc = bench.roc_auc_score(y_train, y_pred)

predict_time, y_pred = bench.measure_function_time(
clf.predict, X_test, params=params)
test_acc = 100 * accuracy_score(y_pred, y_test)
test_acc = bench.accuracy_score(y_test, y_pred)
test_log_loss = bench.log_loss(y_test, y_pred)
test_roc_auc = bench.roc_auc_score(y_test, y_pred)

bench.print_output(library='sklearn', algorithm='decision_forest_classification',
stages=['training', 'prediction'], params=params,
functions=['df_clsf.fit', 'df_clsf.predict'],
times=[fit_time, predict_time], accuracy_type='accuracy[%]',
accuracies=[train_acc, test_acc], data=[X_train, X_test],
alg_instance=clf)
bench.print_output(
library='sklearn',
algorithm='decision_forest_classification',
stages=['training', 'prediction'],
params=params,
functions=['df_clsf.fit', 'df_clsf.predict'],
times=[fit_time, predict_time],
accuracy_type=['accuracy', 'log_loss', 'roc_auc'],
accuracies=[
[train_acc, test_acc],
[train_log_loss, test_log_loss],
[train_roc_auc, test_roc_auc],
],
data=[X_train, X_test],
alg_instance=clf,
)


if __name__ == "__main__":
Expand Down
Loading