From 2651a253768f4bd0a65569168ff7e6808cf49f54 Mon Sep 17 00:00:00 2001 From: Ligeng Zhu Date: Tue, 23 Jun 2020 22:14:22 -0400 Subject: [PATCH] Update dist_tuto.rst --- intermediate_source/dist_tuto.rst | 45 +++++++++++++++---------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/intermediate_source/dist_tuto.rst b/intermediate_source/dist_tuto.rst index 3a76bb1dd3c..76538a81c90 100644 --- a/intermediate_source/dist_tuto.rst +++ b/intermediate_source/dist_tuto.rst @@ -394,29 +394,28 @@ using point-to-point collectives. """ Implementation of a ring-reduce with addition. """ def allreduce(send, recv): - rank = dist.get_rank() - size = dist.get_world_size() - send_buff = th.zeros(send.size()) - recv_buff = th.zeros(send.size()) - accum = th.zeros(send.size()) - accum[:] = send[:] - - left = ((rank - 1) + size) % size - right = (rank + 1) % size - - for i in range(size - 1): - if i % 2 == 0: - # Send send_buff - send_req = dist.isend(send_buff, right) - dist.recv(recv_buff, left) - accum[:] += recv[:] - else: - # Send recv_buff - send_req = dist.isend(recv_buff, right) - dist.recv(send_buff, left) - accum[:] += send[:] - send_req.wait() - recv[:] = accum[:] + rank = dist.get_rank() + size = dist.get_world_size() + send_buff = send.clone() + recv_buff = send.clone() + accum = send.clone() + + left = ((rank - 1) + size) % size + right = (rank + 1) % size + + for i in range(size - 1): + if i % 2 == 0: + # Send send_buff + send_req = dist.isend(send_buff, right) + dist.recv(recv_buff, left) + accum[:] += recv_buff[:] + else: + # Send recv_buff + send_req = dist.isend(recv_buff, right) + dist.recv(send_buff, left) + accum[:] += send_buff[:] + send_req.wait() + recv[:] = accum[:] In the above script, the ``allreduce(send, recv)`` function has a slightly different signature than the ones in PyTorch. It takes a