Skip to content

Commit 013fa3c

Browse files
committed
Add assigning special values to estimator params
1 parent a9214c7 commit 013fa3c

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

sklbench/benchmarks/sklearn_estimator.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939

4040
from ..datasets import load_dataset
41+
from ..datasets.special_params import assign_case_special_values_on_run
4142
from ..datasets.transformer import split_and_transform_data
4243
from ..utils.bench_case import get_bench_case_value
4344
from ..utils.common import convert_to_numpy, custom_format, get_module_members
@@ -511,8 +512,6 @@ def measure_sklearn_estimator(
511512
bench_case,
512513
task,
513514
estimator_class,
514-
estimator_methods,
515-
estimator_params,
516515
):
517516
enable_modelbuilders = get_bench_case_value(
518517
bench_case, "algorithm:enable_modelbuilders", False
@@ -530,17 +529,31 @@ def measure_sklearn_estimator(
530529
)
531530
sklearnex_logging_stream = get_sklearnex_logging_stream()
532531

532+
is_dataset_sequence = (
533+
get_bench_case_value(bench_case, "data:dataset_sequence") is not None
534+
)
535+
# TODO Consider if it is possible to do without additional dataset loading
536+
if not is_dataset_sequence:
537+
dataset_info = get_bench_case_value(bench_case, "data")
538+
data, data_descriptor = load_dataset(bench_case, dataset_info)
539+
assign_case_special_values_on_run(bench_case, data, data_descriptor)
540+
541+
# get estimator parameters
542+
estimator_params = get_bench_case_value(
543+
bench_case, "algorithm:estimator_params", dict()
544+
)
545+
546+
# get estimator methods for measurement
547+
estimator_methods = get_estimator_methods(bench_case)
548+
533549
metrics = dict()
550+
534551
estimator_instance = estimator_class(**estimator_params)
535552
for stage in estimator_methods.keys():
536553
for method in estimator_methods[stage]:
537554
if hasattr(estimator_instance, method):
538555
method_instance = getattr(estimator_instance, method)
539556
if method == "partial_fit":
540-
is_dataset_sequence = (
541-
get_bench_case_value(bench_case, "data:dataset_sequence")
542-
is not None
543-
)
544557
if is_dataset_sequence:
545558
function_to_measure = create_online_function_for_big_data(
546559
bench_case, estimator_instance, method_instance, stage
@@ -606,14 +619,6 @@ def main(bench_case: BenchCase, filters: List[BenchCase]):
606619
estimator_class = get_estimator(library_name, estimator_name)
607620
task = estimator_to_task(estimator_name)
608621

609-
# get estimator parameters
610-
estimator_params = get_bench_case_value(
611-
bench_case, "algorithm:estimator_params", dict()
612-
)
613-
614-
# get estimator methods for measurement
615-
estimator_methods = get_estimator_methods(bench_case)
616-
617622
# benchmark case filtering
618623
if not bench_case_filter(bench_case, filters):
619624
logger.warning("Benchmarking case was filtered.")
@@ -626,8 +631,6 @@ def main(bench_case: BenchCase, filters: List[BenchCase]):
626631
bench_case,
627632
task,
628633
estimator_class,
629-
estimator_methods,
630-
estimator_params,
631634
)
632635

633636
result_template = {
@@ -648,6 +651,7 @@ def main(bench_case: BenchCase, filters: List[BenchCase]):
648651
"training": data_description["x_train"],
649652
"inference": data_description["x_test"],
650653
}
654+
estimator_methods = get_estimator_methods(bench_case)
651655
for stage in estimator_methods.keys():
652656
data_descs[stage].update(
653657
{

0 commit comments

Comments
 (0)