Skip to content

Commit e318f64

Browse files
committed
Add dpnp and dpctl support
1 parent 1c6bd66 commit e318f64

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

sklbench/benchmarks/sklearn_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def get_subset_metrics_of_estimator(
121121
metrics = dict()
122122
# Note: use data[0, 1] when calling estimator methods,
123123
# x, y are numpy ndarrays for compatibility with sklearn metrics (dpnp/dpctl inputs are left unconverted because dp_compat=True)
124-
x, y = list(map(convert_to_numpy, data))
124+
x, y = list(map(lambda i: convert_to_numpy(i, dp_compat=True), data))
125125
if stage == "training":
126126
if hasattr(estimator_instance, "n_iter_"):
127127
iterations = estimator_instance.n_iter_

sklbench/datasets/transformer.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ..utils.logger import logger
2626

2727

28-
def convert_data(data, dformat: str, order: str, dtype: str):
28+
def convert_data(data, dformat: str, order: str, dtype: str, device: str = None):
2929
if isinstance(data, csr_matrix) and dformat != "csr_matrix":
3030
data = data.toarray()
3131
if dtype == "preserve":
@@ -42,6 +42,14 @@ def convert_data(data, dformat: str, order: str, dtype: str):
4242
if data.ndim == 1:
4343
return pd.Series(data)
4444
return pd.DataFrame(data)
45+
elif dformat == "dpnp":
46+
import dpnp
47+
48+
return dpnp.array(data, dtype=dtype, order=order, device=device)
49+
elif dformat == "dpctl":
50+
import dpctl.tensor
51+
52+
return dpctl.tensor.asarray(data, dtype=dtype, order=order, device=device)
4553
elif dformat.startswith("modin"):
4654
if dformat.endswith("ray"):
4755
os.environ["MODIN_ENGINE"] = "ray"
@@ -100,6 +108,7 @@ def split_and_transform_data(bench_case, data, data_description):
100108
x_train, x_test = train_test_split_wrapper(x, **split_kwargs)
101109
y_train, y_test = None, None
102110

111+
device = get_bench_case_value(bench_case, "algorithm:device", None)
103112
common_data_format = get_bench_case_value(bench_case, "data:format", "pandas")
104113
common_data_order = get_bench_case_value(bench_case, "data:order", "F")
105114
common_data_dtype = get_bench_case_value(bench_case, "data:dtype", "float64")
@@ -134,7 +143,9 @@ def split_and_transform_data(bench_case, data, data_description):
134143
if is_label and required_label_dtype is not None:
135144
data_dtype = required_label_dtype
136145

137-
converted_data = convert_data(subset_content, data_format, data_order, data_dtype)
146+
converted_data = convert_data(
147+
subset_content, data_format, data_order, data_dtype, device
148+
)
138149
data_dict[subset_name] = converted_data
139150
if not is_label:
140151
data_description[subset_name] = {

sklbench/utils/common.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,13 +193,23 @@ def convert_to_numeric_if_possible(value: str) -> Union[Numeric, str]:
193193
return value
194194

195195

196-
def convert_to_numpy(a):
196+
def convert_to_numpy(a, dp_compat=False):
197+
if dp_compat and ("dpctl" in str(type(a)) or "dpnp" in str(type(a))):
198+
return a
197199
if isinstance(a, np.ndarray):
198200
return a
199201
elif hasattr(a, "to_numpy"):
200202
return a.to_numpy()
201203
elif hasattr(a, "asnumpy"):
202204
return a.asnumpy()
205+
elif "dpnp" in str(type(a)):
206+
import dpnp
207+
208+
return dpnp.asnumpy(a)
209+
elif "dpctl" in str(type(a)):
210+
import dpctl.tensor
211+
212+
return dpctl.tensor.to_numpy(a)
203213
elif "cupy.ndarray" in str(type(a)):
204214
return a.get()
205215
else:

0 commit comments

Comments
 (0)