diff --git a/intermediate_source/rpc_tutorial.rst b/intermediate_source/rpc_tutorial.rst
index 81111607de9..cd883930030 100644
--- a/intermediate_source/rpc_tutorial.rst
+++ b/intermediate_source/rpc_tutorial.rst
@@ -598,19 +598,20 @@ accumulate to the same set of ``Tensors``.
 
         # train for 10 iterations
         for epoch in range(10):
-            # create distributed autograd context
             for data, target in get_next_batch():
-                with dist_autograd.context():
+                # create distributed autograd context
+                with dist_autograd.context() as context_id:
                     hidden[0].detach_()
                     hidden[1].detach_()
                     output, hidden = model(data, hidden)
                     loss = criterion(output, target)
                     # run distributed backward pass
-                    dist_autograd.backward([loss])
+                    dist_autograd.backward(context_id, [loss])
                     # run distributed optimizer
-                    opt.step()
-                    # not necessary to zero grads as each iteration creates a different
-                    # distributed autograd context which hosts different grads
+                    opt.step(context_id)
+                    # not necessary to zero grads since they are
+                    # accumulated into the distributed autograd context
+                    # which is reset every iteration.
             print("Training epoch {}".format(epoch))
 
 
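For reference, below is a minimal, self-contained sketch of the updated `context_id` API that this patch documents. It assumes a single-process `worker0` RPC setup and a toy parameter behind a local `RRef` as a hypothetical stand-in for the tutorial's remote RNN model; it illustrates the API change, not the tutorial's training code.

```python
import os

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

# single-process RPC setup for illustration only
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
rpc.init_rpc("worker0", rank=0, world_size=1)

# toy parameter behind a local RRef (hypothetical stand-in for the
# tutorial's remote RNN parameters)
param_rref = rpc.RRef(torch.randn(3, 3, requires_grad=True))
opt = DistributedOptimizer(optim.SGD, [param_rref], lr=0.05)

# each iteration opens a fresh autograd context; its context_id is now
# threaded explicitly through backward() and step(), replacing the old
# implicit current-context lookup
with dist_autograd.context() as context_id:
    loss = param_rref.to_here().sum()
    dist_autograd.backward(context_id, [loss])
    # gradients are accumulated in this context rather than in
    # param.grad, so there is nothing to zero between iterations
    opt.step(context_id)

rpc.shutdown()
```

Making `context_id` explicit, as I understand the design, is what allows several trainers to run concurrent backward passes without clobbering each other's gradients: each pass accumulates into its own context, which is discarded when the `with` block exits.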