Skip to content

Commit 3b67cc4

Browse files
committed
refactoring
1 parent 63defad commit 3b67cc4

File tree

2 files changed

+37
-38
lines changed

2 files changed

+37
-38
lines changed

bench.py

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -389,14 +389,13 @@ def convert_data(data, dtype, data_order, data_format):
389389
# Secondly, change format of data
390390
if data_format == 'numpy':
391391
return data
392-
elif data_format == 'pandas':
392+
if data_format == 'pandas':
393393
import pandas as pd
394394

395395
if data.ndim == 1:
396396
return pd.Series(data)
397-
else:
398-
return pd.DataFrame(data)
399-
elif data_format == 'cudf':
397+
return pd.DataFrame(data)
398+
if data_format == 'cudf':
400399
import cudf
401400
import pandas as pd
402401

@@ -516,36 +515,36 @@ def gen_basic_dict(library, algorithm, stage, params, data, alg_instance=None,
516515
def print_output(library, algorithm, stages, params, functions,
517516
times, metric_type, metrics, data, alg_instance=None,
518517
alg_params=None):
519-
if params.output_format == 'json':
520-
output = []
521-
for i, stage in enumerate(stages):
522-
result = gen_basic_dict(library, algorithm, stage, params,
523-
data[i], alg_instance, alg_params)
524-
result.update({'time[s]': times[i]})
525-
if metric_type is not None:
526-
if isinstance(metric_type, str):
527-
result.update({f'{metric_type}': metrics[i]})
528-
elif isinstance(metric_type, list):
529-
for ind, val in enumerate(metric_type):
530-
if metrics[ind][i] is not None:
531-
result.update({f'{val}': metrics[ind][i]})
532-
if hasattr(params, 'n_classes'):
533-
result['input_data'].update({'classes': params.n_classes})
534-
if hasattr(params, 'n_clusters'):
535-
if algorithm == 'kmeans':
536-
result['input_data'].update(
537-
{'n_clusters': params.n_clusters})
538-
elif algorithm == 'dbscan':
539-
result.update({'n_clusters': params.n_clusters})
540-
# replace non-string init with string for kmeans benchmarks
541-
if alg_instance is not None:
542-
if 'init' in result['algorithm_parameters'].keys():
543-
if not isinstance(result['algorithm_parameters']['init'], str):
544-
result['algorithm_parameters']['init'] = 'random'
545-
if 'handle' in result['algorithm_parameters'].keys():
546-
del result['algorithm_parameters']['handle']
547-
output.append(result)
548-
print(json.dumps(output, indent=4))
518+
if params.output_format != 'json': return
519+
output = []
520+
for i, stage in enumerate(stages):
521+
result = gen_basic_dict(library, algorithm, stage, params,
522+
data[i], alg_instance, alg_params)
523+
result.update({'time[s]': times[i]})
524+
if metric_type is not None:
525+
if isinstance(metric_type, str):
526+
result.update({f'{metric_type}': metrics[i]})
527+
elif isinstance(metric_type, list):
528+
for ind, val in enumerate(metric_type):
529+
if metrics[ind][i] is not None:
530+
result.update({f'{val}': metrics[ind][i]})
531+
if hasattr(params, 'n_classes'):
532+
result['input_data'].update({'classes': params.n_classes})
533+
if hasattr(params, 'n_clusters'):
534+
if algorithm == 'kmeans':
535+
result['input_data'].update(
536+
{'n_clusters': params.n_clusters})
537+
elif algorithm == 'dbscan':
538+
result.update({'n_clusters': params.n_clusters})
539+
# replace non-string init with string for kmeans benchmarks
540+
if alg_instance is not None:
541+
if 'init' in result['algorithm_parameters'].keys() and \
542+
not isinstance(result['algorithm_parameters']['init'], str):
543+
result['algorithm_parameters']['init'] = 'random'
544+
if 'handle' in result['algorithm_parameters'].keys():
545+
del result['algorithm_parameters']['handle']
546+
output.append(result)
547+
print(json.dumps(output, indent=4))
549548

550549

551550
def run_with_context(params, function):

utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,11 +175,11 @@ def generate_cases(params: Dict[str, Union[List[Any], Any]]) -> List[str]:
175175
commands *= len(values)
176176
dashes = '-' if len(param) == 1 else '--'
177177
for command_num in range(prev_len):
178-
for value_num in range(len(values)):
179-
commands[prev_len * value_num + command_num] += ' ' + \
180-
dashes + param + ' ' + str(values[value_num])
178+
for idx, val in enumerate(values):
179+
commands[prev_len * idx + command_num] += ' ' + \
180+
dashes + param + ' ' + str(val)
181181
else:
182182
dashes = '-' if len(param) == 1 else '--'
183-
for command_num in range(len(commands)):
183+
for command_num,_ in enumerate(commands):
184184
commands[command_num] += ' ' + dashes + param + ' ' + str(values)
185185
return commands

0 commit comments

Comments
 (0)