Skip to content

Commit 03a3ae2

Browse files
authored
Update references/classification/train.py
1 parent 0d99735 commit 03a3ae2

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

references/classification/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ def load_data(traindir, valdir, args):
127127
if args.cache_dataset and os.path.exists(cache_path):
128128
# Attention, as the transforms are also cached!
129129
print(f"Loading dataset_train from {cache_path}")
130-
dataset, _ = torch.load(cache_path, weights_only=True)
130+
# TODO: this could probably be weights_only=True
131+
dataset, _ = torch.load(cache_path, weights_only=False)
131132
else:
132133
# We need a default value for the variables below because args may come
133134
# from train_quantization.py which doesn't define them.

0 commit comments

Comments
 (0)