Skip to content

Commit 0b44688

Browse files
committed
refactor?
1 parent 3cbc5a6 commit 0b44688

File tree

1 file changed

+49
-24
lines changed

1 file changed

+49
-24
lines changed

bench.py

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,31 @@ def gen_basic_dict(library, algorithm, stage, params, data, alg_instance=None,
512512
return result
513513

514514

515+
def update_result_dict(result) -> None:
516+
result.update({'time[s]': times[i]})
517+
if metric_type is not None:
518+
if isinstance(metric_type, str):
519+
result.update({f'{metric_type}': metrics[i]})
520+
elif isinstance(metric_type, list):
521+
for ind, val in enumerate(metric_type):
522+
if metrics[ind][i] is not None:
523+
result.update({f'{val}': metrics[ind][i]})
524+
if hasattr(params, 'n_classes'):
525+
result['input_data'].update({'classes': params.n_classes})
526+
if hasattr(params, 'n_clusters'):
527+
if algorithm == 'kmeans':
528+
result['input_data'].update(
529+
{'n_clusters': params.n_clusters})
530+
elif algorithm == 'dbscan':
531+
result.update({'n_clusters': params.n_clusters})
532+
# replace non-string init with string for kmeans benchmarks
533+
if alg_instance is not None:
534+
if 'init' in result['algorithm_parameters'].keys():
535+
if not isinstance(result['algorithm_parameters']['init'], str):
536+
result['algorithm_parameters']['init'] = 'random'
537+
if 'handle' in result['algorithm_parameters'].keys():
538+
del result['algorithm_parameters']['handle']
539+
515540
def print_output(library, algorithm, stages, params, functions,
516541
times, metric_type, metrics, data, alg_instance=None,
517542
alg_params=None):
@@ -521,30 +546,30 @@ def print_output(library, algorithm, stages, params, functions,
521546
for i, stage in enumerate(stages):
522547
result = gen_basic_dict(library, algorithm, stage, params,
523548
data[i], alg_instance, alg_params)
524-
result.update({'time[s]': times[i]})
525-
if metric_type is not None:
526-
if isinstance(metric_type, str):
527-
result.update({f'{metric_type}': metrics[i]})
528-
elif isinstance(metric_type, list):
529-
for ind, val in enumerate(metric_type):
530-
if metrics[ind][i] is not None:
531-
result.update({f'{val}': metrics[ind][i]})
532-
if hasattr(params, 'n_classes'):
533-
result['input_data'].update({'classes': params.n_classes})
534-
if hasattr(params, 'n_clusters'):
535-
if algorithm == 'kmeans':
536-
result['input_data'].update(
537-
{'n_clusters': params.n_clusters})
538-
elif algorithm == 'dbscan':
539-
result.update({'n_clusters': params.n_clusters})
540-
# replace non-string init with string for kmeans benchmarks
541-
if alg_instance is not None:
542-
condition = 'init' in result['algorithm_parameters'].keys() and\
543-
not isinstance(result['algorithm_parameters']['init'], str)
544-
if condition:
545-
result['algorithm_parameters']['init'] = 'random'
546-
if 'handle' in result['algorithm_parameters'].keys():
547-
del result['algorithm_parameters']['handle']
549+
update_result_dict(result)
550+
# result.update({'time[s]': times[i]})
551+
# if metric_type is not None:
552+
# if isinstance(metric_type, str):
553+
# result.update({f'{metric_type}': metrics[i]})
554+
# elif isinstance(metric_type, list):
555+
# for ind, val in enumerate(metric_type):
556+
# if metrics[ind][i] is not None:
557+
# result.update({f'{val}': metrics[ind][i]})
558+
# if hasattr(params, 'n_classes'):
559+
# result['input_data'].update({'classes': params.n_classes})
560+
# if hasattr(params, 'n_clusters'):
561+
# if algorithm == 'kmeans':
562+
# result['input_data'].update(
563+
# {'n_clusters': params.n_clusters})
564+
# elif algorithm == 'dbscan':
565+
# result.update({'n_clusters': params.n_clusters})
566+
# # replace non-string init with string for kmeans benchmarks
567+
# if alg_instance is not None:
568+
# if 'init' in result['algorithm_parameters'].keys():
569+
# if not isinstance(result['algorithm_parameters']['init'], str):
570+
# result['algorithm_parameters']['init'] = 'random'
571+
# if 'handle' in result['algorithm_parameters'].keys():
572+
# del result['algorithm_parameters']['handle']
548573
output.append(result)
549574
print(json.dumps(output, indent=4))
550575

0 commit comments

Comments
 (0)