@@ -394,29 +394,28 @@ using point-to-point collectives.
394
394
395
395
def allreduce(send, recv):
    """Implementation of a ring-reduce with addition.

    Each rank repeatedly passes a buffer to its right-hand neighbour while
    receiving one from its left-hand neighbour, adding every received buffer
    into a running total. After ``world_size - 1`` steps each rank has seen
    the contribution of every other rank exactly once.

    Args:
        send: Tensor holding this rank's contribution. It is cloned, so the
            caller's tensor is left unmodified.
        recv: Tensor that is overwritten in place with the accumulated sum.
    """
    rank = dist.get_rank()
    size = dist.get_world_size()

    # Two communication buffers: while one is being sent to the right, the
    # other receives from the left; their roles alternate on every step.
    send_buff = send.clone()
    recv_buff = send.clone()
    # Running total, seeded with this rank's own contribution.
    accum = send.clone()

    left = ((rank - 1) + size) % size
    right = (rank + 1) % size

    for i in range(size - 1):
        if i % 2 == 0:
            # Send send_buff
            send_req = dist.isend(send_buff, right)
            dist.recv(recv_buff, left)
            # Accumulate what was actually received on this step.
            accum += recv_buff
        else:
            # Send recv_buff
            send_req = dist.isend(recv_buff, right)
            dist.recv(send_buff, left)
            accum += send_buff
        # The outgoing buffer is overwritten by the next step's receive, so
        # the non-blocking send must complete before moving on.
        send_req.wait()

    # In-place copy (rather than ``recv[:] = ...``) so that 0-dim tensors,
    # which cannot be sliced, are handled as well.
    recv.copy_(accum)
420
419
421
420
In the above script, the ``allreduce(send, recv)`` function has a
422
421
slightly different signature than the ones in PyTorch. It takes a
0 commit comments