|
70 | 70 | [transforms.ToTensor(),
|
71 | 71 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
72 | 72 |
|
| 73 | +batch_size = 4 |
| 74 | + |
73 | 75 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
|
74 | 76 | download=True, transform=transform)
|
75 |
| -trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, |
| 77 | +trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, |
76 | 78 | shuffle=True, num_workers=2)
|
77 | 79 |
|
78 | 80 | testset = torchvision.datasets.CIFAR10(root='./data', train=False,
|
79 | 81 | download=True, transform=transform)
|
80 |
| -testloader = torch.utils.data.DataLoader(testset, batch_size=4, |
| 82 | +testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, |
81 | 83 | shuffle=False, num_workers=2)
|
82 | 84 |
|
83 | 85 | classes = ('plane', 'car', 'bird', 'cat',
|
@@ -106,7 +108,7 @@ def imshow(img):
|
106 | 108 | # show images
|
107 | 109 | imshow(torchvision.utils.make_grid(images))
|
108 | 110 | # print labels
|
109 |
| -print(' '.join('%5s' % classes[labels[j]] for j in range(4))) |
| 111 | +print(' '.join('%5s' % classes[labels[j]] for j in range(batch_size))) |
110 | 112 |
|
111 | 113 |
|
112 | 114 | ########################################################################
|
|
0 commit comments