diff --git a/beginner_source/data_loading_tutorial.py b/beginner_source/data_loading_tutorial.py index be3e29f77b3..b31de5088c4 100644 --- a/beginner_source/data_loading_tutorial.py +++ b/beginner_source/data_loading_tutorial.py @@ -379,13 +379,14 @@ def show_landmarks_batch(sample_batched): sample_batched['image'], sample_batched['landmarks'] batch_size = len(images_batch) im_size = images_batch.size(2) + grid_border_size = 2 grid = utils.make_grid(images_batch) plt.imshow(grid.numpy().transpose((1, 2, 0))) for i in range(batch_size): - plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size, - landmarks_batch[i, :, 1].numpy(), + plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size, + landmarks_batch[i, :, 1].numpy() + grid_border_size, s=10, marker='.', c='r') plt.title('Batch from dataloader')