diff --git a/beginner_source/ddp_series_multigpu.rst b/beginner_source/ddp_series_multigpu.rst
index b2be2d36b44..baf92d8f8af 100644
--- a/beginner_source/ddp_series_multigpu.rst
+++ b/beginner_source/ddp_series_multigpu.rst
@@ -123,6 +123,18 @@ Distributing input data
           sampler=DistributedSampler(train_dataset),
       )
 
+- Calling the ``set_epoch()`` method on the ``DistributedSampler`` at the beginning of each epoch is necessary to make shuffling work
+  properly across multiple epochs. Otherwise, the same ordering will be used in each epoch.
+
+.. code:: diff
+
+    def _run_epoch(self, epoch):
+        b_sz = len(next(iter(self.train_data))[0])
+   +    self.train_data.sampler.set_epoch(epoch)
+        for source, targets in self.train_data:
+            ...
+            self._run_batch(source, targets)
+
 Saving model checkpoints
 ~~~~~~~~~~~~~~~~~~~~~~~~
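
The effect that ``set_epoch()`` guards against can be sketched without a GPU or a process group. The snippet below is an illustrative stand-in, not ``DistributedSampler``'s actual code: it mimics the idea that the sampler's shuffle is seeded from the epoch number (PyTorch derives the generator seed from a base seed plus the epoch), so skipping the epoch update replays the same ordering every epoch. The ``epoch_order`` helper is hypothetical, introduced only for this sketch.

```python
import random

def epoch_order(indices, epoch):
    """Deterministically shuffle `indices` using a seed derived from `epoch`.

    Illustrative stand-in for DistributedSampler's per-epoch shuffling:
    if the epoch never changes (i.e. set_epoch() is never called), the
    seed never changes and every epoch sees the identical ordering.
    """
    rng = random.Random(epoch)  # seed tied to the epoch number
    shuffled = list(indices)
    rng.shuffle(shuffled)
    return shuffled

indices = list(range(8))

# Same epoch -> same ordering: deterministic and reproducible.
assert epoch_order(indices, 0) == epoch_order(indices, 0)

# The result is still a permutation of the full dataset.
assert sorted(epoch_order(indices, 5)) == indices
```

In real training code, calling ``self.train_data.sampler.set_epoch(epoch)`` once per epoch plays the role of passing a fresh ``epoch`` value here, so each epoch draws a different permutation.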