Commit fa3f1a5

Merge pull request #884 from rohan-varma/fix_tut
Update RPC tutorial with new calls to dist autograd and optimizer
2 parents c8591e7 + 082f631 commit fa3f1a5

File tree

1 file changed: +7 additions, -6 deletions


intermediate_source/rpc_tutorial.rst

Lines changed: 7 additions & 6 deletions
@@ -598,19 +598,20 @@ accumulate to the same set of ``Tensors``.
 
 # train for 10 iterations
 for epoch in range(10):
-    # create distributed autograd context
     for data, target in get_next_batch():
-        with dist_autograd.context():
+        # create distributed autograd context
+        with dist_autograd.context() as context_id:
             hidden[0].detach_()
             hidden[1].detach_()
             output, hidden = model(data, hidden)
             loss = criterion(output, target)
             # run distributed backward pass
-            dist_autograd.backward([loss])
+            dist_autograd.backward(context_id, [loss])
             # run distributed optimizer
-            opt.step()
-            # not necessary to zero grads as each iteration creates a different
-            # distributed autograd context which hosts different grads
+            opt.step(context_id)
+            # not necessary to zero grads since they are
+            # accumulated into the distributed autograd context
+            # which is reset every iteration.
     print("Training epoch {}".format(epoch))
 
 
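The net effect of the diff is that both the distributed backward pass and the distributed optimizer step are now keyed by the context_id of the distributed autograd context opened for that iteration, and gradients accumulate inside that context rather than in the parameters' .grad fields. The following is a minimal, self-contained sketch of that pattern, not taken from the tutorial: the single-process RPC setup, the nn.Linear stand-in model, the worker name, port, and learning rate are all illustrative assumptions.

import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef


def train():
    # stand-in model and loss; the tutorial itself uses an RNN split across workers
    model = nn.Linear(4, 2)
    criterion = nn.MSELoss()

    # DistributedOptimizer takes RRefs to the parameters it should update;
    # here they are local RRefs because everything lives on one worker
    opt = DistributedOptimizer(
        optim.SGD,
        [RRef(p) for p in model.parameters()],
        lr=0.05,  # illustrative hyperparameter
    )

    for epoch in range(10):
        data, target = torch.randn(8, 4), torch.randn(8, 2)  # fake batch
        # each iteration opens a fresh distributed autograd context
        with dist_autograd.context() as context_id:
            loss = criterion(model(data), target)
            # backward pass and optimizer step are both keyed by context_id
            dist_autograd.backward(context_id, [loss])
            opt.step(context_id)
            # no zero_grad(): gradients live in this per-iteration context,
            # which is discarded when the `with` block exits
        print("Training epoch {}".format(epoch))


if __name__ == "__main__":
    # single-process RPC setup so the sketch can run standalone;
    # address, port, and worker name are placeholders
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    rpc.init_rpc("worker0", rank=0, world_size=1)
    train()
    rpc.shutdown()

The same context_id plumbing carries over to the multi-worker setup in the tutorial; only the parameter RRefs handed to DistributedOptimizer and the model's forward calls differ.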
