Skip to content

Commit a7aa02c

Browse files
justOlehholly1238
andauthored
Update cifar10_tutorial.py (#1258)
small refactoring Co-authored-by: holly1238 <77758406+holly1238@users.noreply.github.com>
1 parent 9b377fd commit a7aa02c

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

beginner_source/blitz/cifar10_tutorial.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,16 @@
7070
[transforms.ToTensor(),
7171
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
7272

73+
batch_size = 4
74+
7375
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
7476
download=True, transform=transform)
75-
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
77+
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
7678
shuffle=True, num_workers=2)
7779

7880
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
7981
download=True, transform=transform)
80-
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
82+
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
8183
shuffle=False, num_workers=2)
8284

8385
classes = ('plane', 'car', 'bird', 'cat',
@@ -106,7 +108,7 @@ def imshow(img):
106108
# show images
107109
imshow(torchvision.utils.make_grid(images))
108110
# print labels
109-
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
111+
print(' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))
110112

111113

112114
########################################################################

0 commit comments

Comments
 (0)