1 file changed: +7 −6

@@ -598,19 +598,20 @@ accumulate to the same set of ``Tensors``.
 
 # train for 10 iterations
 for epoch in range(10):
-    # create distributed autograd context
     for data, target in get_next_batch():
-        with dist_autograd.context():
+        # create distributed autograd context
+        with dist_autograd.context() as context_id:
             hidden[0].detach_()
             hidden[1].detach_()
             output, hidden = model(data, hidden)
             loss = criterion(output, target)
             # run distributed backward pass
-            dist_autograd.backward([loss])
+            dist_autograd.backward(context_id, [loss])
             # run distributed optimizer
-            opt.step()
-            # not necessary to zero grads as each iteration creates a different
-            # distributed autograd context which hosts different grads
+            opt.step(context_id)
+            # not necessary to zero grads since they are
+            # accumulated into the distributed autograd context
+            # which is reset every iteration.
     print("Training epoch {}".format(epoch))
 
 
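For context, here is a minimal standalone sketch of the pattern this diff moves the tutorial to: the distributed autograd context now yields a context_id that must be passed explicitly to dist_autograd.backward() and DistributedOptimizer.step(). This is illustrative, not the tutorial's code: it assumes a single-process RPC setup and a toy linear model in place of the tutorial's RNN, and the worker name, env-var defaults, and hyperparameters are made up for the sketch (get_next_batch, model, and criterion in the diff above are defined elsewhere in the tutorial).

import os

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer

# single-process RPC setup so the sketch runs standalone (illustrative values)
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
rpc.init_rpc("worker0", rank=0, world_size=1)

# toy linear model standing in for the tutorial's RNN
model = torch.nn.Linear(4, 1)
opt = DistributedOptimizer(
    optim.SGD,
    [rpc.RRef(p) for p in model.parameters()],
    lr=0.05,
)

for step in range(2):
    # each iteration opens a fresh context; gradients accumulate into the
    # context rather than into param.grad, so there is nothing to zero
    with dist_autograd.context() as context_id:
        loss = model(torch.randn(2, 4)).sum()
        dist_autograd.backward(context_id, [loss])
        # gradients live in the context, keyed by Tensor
        grads = dist_autograd.get_gradients(context_id)
        print("step", step, "- tensors with grads in context:", len(grads))
        opt.step(context_id)

rpc.shutdown()

This also shows why the diff rewrites the "zero grads" comment: gradients are held in the per-iteration autograd context (inspectable via dist_autograd.get_gradients), which is discarded when the with block exits, rather than persisting on the parameters across iterations.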