Skip to content

Commit 6fc17ef

Browse files
committed
refactor
1 parent 786acb4 commit 6fc17ef

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

bench.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -448,25 +448,25 @@ def load_data(params, generated_data=[], add_dtype=False, label_2d=False,
448448
int_dtype if 'y' in element and int_label else params.dtype,
449449
params.data_order, params.data_format
450450
)
451-
# generate and convert data if it's marked and path isn't specified
452-
if full_data[element] is None and element in generated_data:
453-
full_data[element] = convert_data(
454-
np.random.rand(*params.shape),
455-
int_dtype if 'y' in element and int_label else params.dtype,
456-
params.data_order, params.data_format)
457-
# convert existing labels from 1- to 2-dimensional
458-
# if it's forced and possible
459-
if full_data[element] is not None and 'y' in element \
460-
and label_2d and hasattr(full_data[element], 'reshape'):
461-
full_data[element] = full_data[element].reshape(
462-
(full_data[element].shape[0], 1))
463-
# add dtype property to data if it's needed and doesn't exist
464-
if full_data[element] is not None and add_dtype and \
465-
not hasattr(full_data[element], 'dtype'):
466-
if hasattr(full_data[element], 'values'):
467-
full_data[element].dtype = full_data[element].values.dtype
468-
elif hasattr(full_data[element], 'dtypes'):
469-
full_data[element].dtype = full_data[element].dtypes[0].type
451+
if full_data[element] is None:
452+
# generate and convert data if it's marked and path isn't specified
453+
if element in generated_data:
454+
full_data[element] = convert_data(
455+
np.random.rand(*params.shape),
456+
int_dtype if 'y' in element and int_label else params.dtype,
457+
params.data_order, params.data_format)
458+
else:
459+
# convert existing labels from 1- to 2-dimensional
460+
# if it's forced and possible
461+
if 'y' in element and label_2d and hasattr(full_data[element], 'reshape'):
462+
full_data[element] = full_data[element].reshape(
463+
(full_data[element].shape[0], 1))
464+
# add dtype property to data if it's needed and doesn't exist
465+
if add_dtype and not hasattr(full_data[element], 'dtype'):
466+
if hasattr(full_data[element], 'values'):
467+
full_data[element].dtype = full_data[element].values.dtype
468+
elif hasattr(full_data[element], 'dtypes'):
469+
full_data[element].dtype = full_data[element].dtypes[0].type
470470

471471
params.dtype = get_dtype(full_data['X_train'])
472472
# add size to parameters which is need for some cases

0 commit comments

Comments
 (0)