@@ -458,9 +458,7 @@ def load_data(params, generated_data=[], add_dtype=False, label_2d=False,
458
458
else :
459
459
# convert existing labels from 1- to 2-dimensional
460
460
# if it's forced and possible
461
- condition1 : bool = 'y' in element and label_2d
462
- condition1 = condition1 and hasattr (full_data [element ], 'reshape' )
463
- if condition1 :
461
+ if 'y' in element and label_2d and hasattr (full_data [element ], 'reshape' ):
464
462
full_data [element ] = full_data [element ].reshape (
465
463
(full_data [element ].shape [0 ], 1 ))
466
464
# add dtype property to data if it's needed and doesn't exist
@@ -482,8 +480,7 @@ def load_data(params, generated_data=[], add_dtype=False, label_2d=False,
482
480
return tuple (full_data .values ())
483
481
484
482
485
- def gen_basic_dict (library , algorithm , stage , params , data , alg_instance = None ,
486
- alg_params = None ):
483
+ def gen_basic_dict (library , algorithm , stage , params , data ):
487
484
result = {
488
485
'library' : library ,
489
486
'algorithm' : algorithm ,
@@ -498,6 +495,9 @@ def gen_basic_dict(library, algorithm, stage, params, data, alg_instance=None,
498
495
'columns' : data .shape [1 ]
499
496
}
500
497
}
498
+ return result
499
+
500
+ def update_algorithm_parameters (result , alg_instance = None , alg_params = None ):
501
501
result ['algorithm_parameters' ] = {}
502
502
if alg_instance is not None :
503
503
if 'Booster' in str (type (alg_instance )):
@@ -509,8 +509,15 @@ def gen_basic_dict(library, algorithm, stage, params, data, alg_instance=None,
509
509
alg_instance_params ['dtype' ] = str (
510
510
alg_instance_params ['dtype' ])
511
511
result ['algorithm_parameters' ].update (alg_instance_params )
512
+ if 'init' in result ['algorithm_parameters' ]:
513
+ if not isinstance (result ['algorithm_parameters' ]['init' ], str ):
514
+ result ['algorithm_parameters' ]['init' ] = 'random'
512
515
if alg_params is not None :
513
516
result ['algorithm_parameters' ].update (alg_params )
517
+ if 'init' in result ['algorithm_parameters' ].keys ():
518
+ if not isinstance (result ['algorithm_parameters' ]['init' ], str ):
519
+ result ['algorithm_parameters' ]['init' ] = 'random'
520
+ result ['algorithm_parameters' ].pop ('handle' ,None )
514
521
return result
515
522
516
523
@@ -521,8 +528,7 @@ def print_output(library, algorithm, stages, params, functions,
521
528
return
522
529
output = []
523
530
for i , stage in enumerate (stages ):
524
- result = gen_basic_dict (library , algorithm , stage , params ,
525
- data [i ], alg_instance , alg_params )
531
+ result = gen_basic_dict (library , algorithm , stage , params , data [i ])
526
532
result .update ({'time[s]' : times [i ]})
527
533
if isinstance (metric_type , str ):
528
534
result .update ({f'{ metric_type } ' : metrics [i ]})
@@ -539,11 +545,7 @@ def print_output(library, algorithm, stages, params, functions,
539
545
elif algorithm == 'dbscan' :
540
546
result .update ({'n_clusters' : params .n_clusters })
541
547
# replace non-string init with string for kmeans benchmarks
542
- if alg_instance is not None :
543
- if 'init' in result ['algorithm_parameters' ].keys ():
544
- if not isinstance (result ['algorithm_parameters' ]['init' ], str ):
545
- result ['algorithm_parameters' ]['init' ] = 'random'
546
- result ['algorithm_parameters' ].pop ('handle' ,None )
548
+ result = update_algorithm_parameters (result , alg_instance , alg_params )
547
549
output .append (result )
548
550
print (json .dumps (output , indent = 4 ))
549
551
0 commit comments