From 5e0868354acafeefda821e3ef58ec3fc3d4a160e Mon Sep 17 00:00:00 2001 From: Oleh <34406675+justOleh@users.noreply.github.com> Date: Mon, 23 Nov 2020 18:07:22 +0200 Subject: [PATCH] Update cifar10_tutorial.py small refactoring --- beginner_source/blitz/cifar10_tutorial.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/beginner_source/blitz/cifar10_tutorial.py b/beginner_source/blitz/cifar10_tutorial.py index 730bf6ac986..9f387c7145c 100644 --- a/beginner_source/blitz/cifar10_tutorial.py +++ b/beginner_source/blitz/cifar10_tutorial.py @@ -70,14 +70,16 @@ [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) +batch_size = 4 + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) -trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, +trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) -testloader = torch.utils.data.DataLoader(testset, batch_size=4, +testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2) classes = ('plane', 'car', 'bird', 'cat', @@ -106,7 +108,7 @@ def imshow(img): # show images imshow(torchvision.utils.make_grid(images)) # print labels -print(' '.join('%5s' % classes[labels[j]] for j in range(4))) +print(' '.join('%5s' % classes[labels[j]] for j in range(batch_size))) ########################################################################