Skip to content

Commit 4693362

Browse files
committed
refactor?
1 parent 3509b1d commit 4693362

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

bench.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -458,9 +458,7 @@ def load_data(params, generated_data=[], add_dtype=False, label_2d=False,
458458
else:
459459
# convert existing labels from 1- to 2-dimensional
460460
# if it's forced and possible
461-
condition1: bool = 'y' in element and label_2d
462-
condition1 = condition1 and hasattr(full_data[element], 'reshape')
463-
if condition1:
461+
if 'y' in element and label_2d and hasattr(full_data[element], 'reshape'):
464462
full_data[element] = full_data[element].reshape(
465463
(full_data[element].shape[0], 1))
466464
# add dtype property to data if it's needed and doesn't exist
@@ -482,8 +480,7 @@ def load_data(params, generated_data=[], add_dtype=False, label_2d=False,
482480
return tuple(full_data.values())
483481

484482

485-
def gen_basic_dict(library, algorithm, stage, params, data, alg_instance=None,
486-
alg_params=None):
483+
def gen_basic_dict(library, algorithm, stage, params, data):
487484
result = {
488485
'library': library,
489486
'algorithm': algorithm,
@@ -498,6 +495,9 @@ def gen_basic_dict(library, algorithm, stage, params, data, alg_instance=None,
498495
'columns': data.shape[1]
499496
}
500497
}
498+
return result
499+
500+
def update_algorithm_parameters(result, alg_instance=None, alg_params=None):
501501
result['algorithm_parameters'] = {}
502502
if alg_instance is not None:
503503
if 'Booster' in str(type(alg_instance)):
@@ -509,8 +509,15 @@ def gen_basic_dict(library, algorithm, stage, params, data, alg_instance=None,
509509
alg_instance_params['dtype'] = str(
510510
alg_instance_params['dtype'])
511511
result['algorithm_parameters'].update(alg_instance_params)
512+
if 'init' in result['algorithm_parameters']:
513+
if not isinstance(result['algorithm_parameters']['init'], str):
514+
result['algorithm_parameters']['init'] = 'random'
512515
if alg_params is not None:
513516
result['algorithm_parameters'].update(alg_params)
517+
if 'init' in result['algorithm_parameters'].keys():
518+
if not isinstance(result['algorithm_parameters']['init'], str):
519+
result['algorithm_parameters']['init'] = 'random'
520+
result['algorithm_parameters'].pop('handle',None)
514521
return result
515522

516523

@@ -521,8 +528,7 @@ def print_output(library, algorithm, stages, params, functions,
521528
return
522529
output = []
523530
for i, stage in enumerate(stages):
524-
result = gen_basic_dict(library, algorithm, stage, params,
525-
data[i], alg_instance, alg_params)
531+
result = gen_basic_dict(library, algorithm, stage, params, data[i])
526532
result.update({'time[s]': times[i]})
527533
if isinstance(metric_type, str):
528534
result.update({f'{metric_type}': metrics[i]})
@@ -539,11 +545,7 @@ def print_output(library, algorithm, stages, params, functions,
539545
elif algorithm == 'dbscan':
540546
result.update({'n_clusters': params.n_clusters})
541547
# replace non-string init with string for kmeans benchmarks
542-
if alg_instance is not None:
543-
if 'init' in result['algorithm_parameters'].keys():
544-
if not isinstance(result['algorithm_parameters']['init'], str):
545-
result['algorithm_parameters']['init'] = 'random'
546-
result['algorithm_parameters'].pop('handle',None)
548+
result = update_algorithm_parameters(result, alg_instance, alg_params)
547549
output.append(result)
548550
print(json.dumps(output, indent=4))
549551

0 commit comments

Comments
 (0)