Skip to content

Commit 97ac576

Browse files
authored
Update references/classification/train.py
1 parent 03a3ae2 commit 97ac576

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

references/classification/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ def load_data(traindir, valdir, args):
160160
if args.cache_dataset and os.path.exists(cache_path):
161161
# Attention, as the transforms are also cached!
162162
print(f"Loading dataset_test from {cache_path}")
163-
dataset_test, _ = torch.load(cache_path, weights_only=True)
163+
# TODO: this could probably be weights_only=True
164+
dataset_test, _ = torch.load(cache_path, weights_only=False)
164165
else:
165166
if args.weights and args.test_only:
166167
weights = torchvision.models.get_weight(args.weights)

0 commit comments

Comments
 (0)