Skip to content

Commit 67032d7

Browse files
author
Jessica Lin
authored
Merge pull request #1039 from Lyken17/master
Fix allreduce implmentation in tutorial
2 parents d40dc05 + 2651a25 commit 67032d7

File tree

1 file changed

+22
-23
lines changed

1 file changed

+22
-23
lines changed

intermediate_source/dist_tuto.rst

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -394,29 +394,28 @@ using point-to-point collectives.
394394
395395
""" Implementation of a ring-reduce with addition. """
396396
def allreduce(send, recv):
397-
rank = dist.get_rank()
398-
size = dist.get_world_size()
399-
send_buff = th.zeros(send.size())
400-
recv_buff = th.zeros(send.size())
401-
accum = th.zeros(send.size())
402-
accum[:] = send[:]
403-
404-
left = ((rank - 1) + size) % size
405-
right = (rank + 1) % size
406-
407-
for i in range(size - 1):
408-
if i % 2 == 0:
409-
# Send send_buff
410-
send_req = dist.isend(send_buff, right)
411-
dist.recv(recv_buff, left)
412-
accum[:] += recv[:]
413-
else:
414-
# Send recv_buff
415-
send_req = dist.isend(recv_buff, right)
416-
dist.recv(send_buff, left)
417-
accum[:] += send[:]
418-
send_req.wait()
419-
recv[:] = accum[:]
397+
rank = dist.get_rank()
398+
size = dist.get_world_size()
399+
send_buff = send.clone()
400+
recv_buff = send.clone()
401+
accum = send.clone()
402+
403+
left = ((rank - 1) + size) % size
404+
right = (rank + 1) % size
405+
406+
for i in range(size - 1):
407+
if i % 2 == 0:
408+
# Send send_buff
409+
send_req = dist.isend(send_buff, right)
410+
dist.recv(recv_buff, left)
411+
accum[:] += recv_buff[:]
412+
else:
413+
# Send recv_buff
414+
send_req = dist.isend(recv_buff, right)
415+
dist.recv(send_buff, left)
416+
accum[:] += send_buff[:]
417+
send_req.wait()
418+
recv[:] = accum[:]
420419
421420
In the above script, the ``allreduce(send, recv)`` function has a
422421
slightly different signature than the ones in PyTorch. It takes a

0 commit comments

Comments
 (0)