Skip to content

Commit 778f3c0

Browse files
committed
refactor load_data?
1 parent 5194413 commit 778f3c0

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

bench.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,22 +438,23 @@ def load_data(params, generated_data=[], add_dtype=False, label_2d=False,
438438
for element in full_data:
439439
file_arg = f'file_{element}'
440440
# load and convert data from npy/csv file if path is specified
441+
new_dtype = int_dtype if 'y' in element and int_label else params.dtype
441442
if param_vars[file_arg] is not None:
442443
if param_vars[file_arg].name.endswith('.npy'):
443444
data = np.load(param_vars[file_arg].name, allow_pickle=True)
444445
else:
445446
data = read_csv(param_vars[file_arg].name, params)
446447
full_data[element] = convert_data(
447448
data,
448-
int_dtype if 'y' in element and int_label else params.dtype,
449+
new_dtype,
449450
params.data_order, params.data_format
450451
)
451452
if full_data[element] is None:
452453
# generate and convert data if it's marked and path isn't specified
453454
if element in generated_data:
454455
full_data[element] = convert_data(
455456
np.random.rand(*params.shape),
456-
int_dtype if 'y' in element and int_label else params.dtype,
457+
new_dtype,
457458
params.data_order, params.data_format)
458459
else:
459460
# convert existing labels from 1- to 2-dimensional

0 commit comments

Comments
 (0)