Skip to content

Commit 0452dd8

Browse files
authored
add detail about set_epoch (#2066)
1 parent 8999538 commit 0452dd8

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

beginner_source/ddp_series_multigpu.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,18 @@ Distributing input data
123123
+ sampler=DistributedSampler(train_dataset),
124124
)
125125
126+
- Calling the ``set_epoch()`` method on the ``DistributedSampler`` at the beginning of each epoch is necessary to make shuffling work
127+
properly across multiple epochs. Otherwise, the same ordering will be used in each epoch.
128+
129+
.. code:: diff
130+
131+
def _run_epoch(self, epoch):
132+
b_sz = len(next(iter(self.train_data))[0])
133+
+ self.train_data.sampler.set_epoch(epoch)
134+
for source, targets in self.train_data:
135+
...
136+
self._run_batch(source, targets)
137+
126138
127139
Saving model checkpoints
128140
~~~~~~~~~~~~~~~~~~~~~~~~

0 commit comments

Comments
 (0)