25
25
from ..utils .logger import logger
26
26
27
27
28
- def convert_data (data , dformat : str , order : str , dtype : str ):
28
+ def convert_data (data , dformat : str , order : str , dtype : str , device : str = None ):
29
29
if isinstance (data , csr_matrix ) and dformat != "csr_matrix" :
30
30
data = data .toarray ()
31
31
if dtype == "preserve" :
@@ -42,6 +42,14 @@ def convert_data(data, dformat: str, order: str, dtype: str):
42
42
if data .ndim == 1 :
43
43
return pd .Series (data )
44
44
return pd .DataFrame (data )
45
+ elif dformat == "dpnp" :
46
+ import dpnp
47
+
48
+ return dpnp .array (data , dtype = dtype , order = order , device = device )
49
+ elif dformat == "dpctl" :
50
+ import dpctl .tensor
51
+
52
+ return dpctl .tensor .asarray (data , dtype = dtype , order = order , device = device )
45
53
elif dformat .startswith ("modin" ):
46
54
if dformat .endswith ("ray" ):
47
55
os .environ ["MODIN_ENGINE" ] = "ray"
@@ -100,6 +108,7 @@ def split_and_transform_data(bench_case, data, data_description):
100
108
x_train , x_test = train_test_split_wrapper (x , ** split_kwargs )
101
109
y_train , y_test = None , None
102
110
111
+ device = get_bench_case_value (bench_case , "algorithm:device" , None )
103
112
common_data_format = get_bench_case_value (bench_case , "data:format" , "pandas" )
104
113
common_data_order = get_bench_case_value (bench_case , "data:order" , "F" )
105
114
common_data_dtype = get_bench_case_value (bench_case , "data:dtype" , "float64" )
@@ -134,7 +143,9 @@ def split_and_transform_data(bench_case, data, data_description):
134
143
if is_label and required_label_dtype is not None :
135
144
data_dtype = required_label_dtype
136
145
137
- converted_data = convert_data (subset_content , data_format , data_order , data_dtype )
146
+ converted_data = convert_data (
147
+ subset_content , data_format , data_order , data_dtype , device
148
+ )
138
149
data_dict [subset_name ] = converted_data
139
150
if not is_label :
140
151
data_description [subset_name ] = {
0 commit comments