
Commit 057e4d1

[doc] minor fixups to DDP tutorial
Summary: Add "set_device" call to keep things consistent between all DDP tutorials. This was inspired by the following change in the PyTorch repo: pytorch/examples#1285 (review)

Test Plan: Ran tutorial with the applied changes and we see:

"""
Running basic DDP example on rank 3.
Running basic DDP example on rank 1.
Running basic DDP example on rank 2.
Running basic DDP example on rank 0.
Finished running basic DDP example on rank 0.
Finished running basic DDP example on rank 1.
Finished running basic DDP example on rank 3.
Finished running basic DDP example on rank 2.
Running DDP checkpoint example on rank 2.
Running DDP checkpoint example on rank 1.
Running DDP checkpoint example on rank 0.
Running DDP checkpoint example on rank 3.
Finished DDP checkpoint example on rank 0.
Finished DDP checkpoint example on rank 3.
Finished DDP checkpoint example on rank 1.
Finished DDP checkpoint example on rank 2.
Running DDP with model parallel example on rank 0.
Running DDP with model parallel example on rank 1.
Finished running DDP with model parallel example on rank 0.
Finished running DDP with model parallel example on rank 1.
"""
1 parent 904ca90 commit 057e4d1
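
For reference, this is roughly how the process-group setup in the tutorial reads after this change (a minimal sketch assembled from the diff below; the setup(rank, world_size) helper name and the imports are assumed from the surrounding tutorial, not shown in this diff):

import os
import torch
import torch.distributed as dist

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # set the default CUDA device for this process so that tensors and
    # collectives land on the GPU matching this rank
    torch.cuda.set_device(rank)

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)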

1 file changed: 21 additions, 11 deletions


intermediate_source/ddp_tutorial.rst

Lines changed: 21 additions & 11 deletions
@@ -99,6 +99,9 @@ be found in
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '12355'

+    # set the device id for this process
+    torch.cuda.set_device(rank)
+
     # initialize the process group
     dist.init_process_group("gloo", rank=rank, world_size=world_size)

@@ -141,6 +144,7 @@ different DDP processes starting from different initial model parameter values.
     optimizer.step()

     cleanup()
+    print(f"Finished running basic DDP example on rank {rank}.")


 def run_demo(demo_fn, world_size):

@@ -182,7 +186,7 @@ for more details. When using DDP, one optimization is to save the model in
 only one process and then load it to all processes, reducing write overhead.
 This is correct because all processes start from the same parameters and
 gradients are synchronized in backward passes, and hence optimizers should keep
-setting parameters to the same values. If you use this optimization, make sure no process starts
+setting parameters to the same values. If you use this optimization, make sure no process starts
 loading before the saving is finished. Additionally, when
 loading the module, you need to provide an appropriate ``map_location``
 argument to prevent a process from stepping into others' devices. If ``map_location``

@@ -218,7 +222,7 @@ and elasticity support, please refer to `TorchElastic <https://pytorch.org/elast

     loss_fn = nn.MSELoss()
     optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
-
+
     optimizer.zero_grad()
     outputs = ddp_model(torch.randn(20, 10))
     labels = torch.randn(20, 5).to(rank)

@@ -234,6 +238,7 @@ and elasticity support, please refer to `TorchElastic <https://pytorch.org/elast
         os.remove(CHECKPOINT_PATH)

     cleanup()
+    print(f"Finished running DDP checkpoint example on rank {rank}.")

 Combining DDP with Model Parallelism
 ------------------------------------

@@ -285,6 +290,7 @@ either the application or the model ``forward()`` method.
     optimizer.step()

     cleanup()
+    print(f"Finished running DDP with model parallel example on rank {rank}.")


 if __name__ == "__main__":

@@ -323,10 +329,13 @@ Let's still use the Toymodel example and create a file named ``elastic_ddp.py``.


 def demo_basic():
-    dist.init_process_group("nccl")
     rank = dist.get_rank()
+    torch.cuda.set_device(rank)
+
+    dist.init_process_group("nccl")
+
     print(f"Start running basic DDP example on rank {rank}.")
-
+
     # create model and move it to GPU with id rank
     device_id = rank % torch.cuda.device_count()
     model = ToyModel().to(device_id)

@@ -340,23 +349,24 @@ Let's still use the Toymodel example and create a file named ``elastic_ddp.py``.
     labels = torch.randn(20, 5).to(device_id)
     loss_fn(outputs, labels).backward()
     optimizer.step()
-    dist.destroy_process_group()
-
+    cleanup()
+    print(f"Finished running basic DDP example on rank {rank}.")
+
 if __name__ == "__main__":
     demo_basic()

-One can then run a `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command
+One can then run a `torch elastic/torchrun <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ command
 on all nodes to initialize the DDP job created above:

 .. code:: bash

     torchrun --nnodes=2 --nproc_per_node=8 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29400 elastic_ddp.py

-We are running the DDP script on two hosts, and each host we run with 8 processes, aka, we
+We are running the DDP script on two hosts, and each host we run with 8 processes, aka, we
 are running it on 16 GPUs. Note that ``$MASTER_ADDR`` must be the same across all nodes.

-Here torchrun will launch 8 process and invoke ``elastic_ddp.py``
-on each process on the node it is launched on, but user also needs to apply cluster
+Here torchrun will launch 8 process and invoke ``elastic_ddp.py``
+on each process on the node it is launched on, but user also needs to apply cluster
 management tools like slurm to actually run this command on 2 nodes.

 For example, on a SLURM enabled cluster, we can write a script to run the command above

@@ -371,5 +381,5 @@ Then we can just run this script using the SLURM command: ``srun --nodes=2 ./tor
 Of course, this is just an example; you can choose your own cluster scheduling tools
 to initiate the torchrun job.

-For more information about Elastic run, one can check this
+For more information about Elastic run, one can check this
 `quick start document <https://pytorch.org/docs/stable/elastic/quickstart.html>`__ to learn more.
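
For context, the cleanup() and run_demo() helpers that the demos above call look roughly like this in the tutorial (a sketch; the exact bodies of these helpers are assumed, since they are not shown in this diff):

import torch.distributed as dist
import torch.multiprocessing as mp

def cleanup():
    # the elastic example now calls this helper instead of invoking
    # dist.destroy_process_group() directly
    dist.destroy_process_group()

def run_demo(demo_fn, world_size):
    # spawn one process per rank; each process runs demo_fn(rank, world_size)
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)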
