38
38
)
39
39
40
40
from ..datasets import load_dataset
41
+ from ..datasets .special_params import assign_case_special_values_on_run
41
42
from ..datasets .transformer import split_and_transform_data
42
43
from ..utils .bench_case import get_bench_case_value
43
44
from ..utils .common import convert_to_numpy , custom_format , get_module_members
@@ -511,8 +512,6 @@ def measure_sklearn_estimator(
511
512
bench_case ,
512
513
task ,
513
514
estimator_class ,
514
- estimator_methods ,
515
- estimator_params ,
516
515
):
517
516
enable_modelbuilders = get_bench_case_value (
518
517
bench_case , "algorithm:enable_modelbuilders" , False
@@ -530,17 +529,31 @@ def measure_sklearn_estimator(
530
529
)
531
530
sklearnex_logging_stream = get_sklearnex_logging_stream ()
532
531
532
+ is_dataset_sequence = (
533
+ get_bench_case_value (bench_case , "data:dataset_sequence" ) is not None
534
+ )
535
+ # TODO Consider if it is possible to do without additional dataset loading
536
+ if not is_dataset_sequence :
537
+ dataset_info = get_bench_case_value (bench_case , "data" )
538
+ data , data_descriptor = load_dataset (bench_case , dataset_info )
539
+ assign_case_special_values_on_run (bench_case , data , data_descriptor )
540
+
541
+ # get estimator parameters
542
+ estimator_params = get_bench_case_value (
543
+ bench_case , "algorithm:estimator_params" , dict ()
544
+ )
545
+
546
+ # get estimator methods for measurement
547
+ estimator_methods = get_estimator_methods (bench_case )
548
+
533
549
metrics = dict ()
550
+
534
551
estimator_instance = estimator_class (** estimator_params )
535
552
for stage in estimator_methods .keys ():
536
553
for method in estimator_methods [stage ]:
537
554
if hasattr (estimator_instance , method ):
538
555
method_instance = getattr (estimator_instance , method )
539
556
if method == "partial_fit" :
540
- is_dataset_sequence = (
541
- get_bench_case_value (bench_case , "data:dataset_sequence" )
542
- is not None
543
- )
544
557
if is_dataset_sequence :
545
558
function_to_measure = create_online_function_for_big_data (
546
559
bench_case , estimator_instance , method_instance , stage
@@ -606,14 +619,6 @@ def main(bench_case: BenchCase, filters: List[BenchCase]):
606
619
estimator_class = get_estimator (library_name , estimator_name )
607
620
task = estimator_to_task (estimator_name )
608
621
609
- # get estimator parameters
610
- estimator_params = get_bench_case_value (
611
- bench_case , "algorithm:estimator_params" , dict ()
612
- )
613
-
614
- # get estimator methods for measurement
615
- estimator_methods = get_estimator_methods (bench_case )
616
-
617
622
# benchmark case filtering
618
623
if not bench_case_filter (bench_case , filters ):
619
624
logger .warning ("Benchmarking case was filtered." )
@@ -626,8 +631,6 @@ def main(bench_case: BenchCase, filters: List[BenchCase]):
626
631
bench_case ,
627
632
task ,
628
633
estimator_class ,
629
- estimator_methods ,
630
- estimator_params ,
631
634
)
632
635
633
636
result_template = {
@@ -648,6 +651,7 @@ def main(bench_case: BenchCase, filters: List[BenchCase]):
648
651
"training" : data_description ["x_train" ],
649
652
"inference" : data_description ["x_test" ],
650
653
}
654
+ estimator_methods = get_estimator_methods (bench_case )
651
655
for stage in estimator_methods .keys ():
652
656
data_descs [stage ].update (
653
657
{
0 commit comments