
Commit dbc4b0f

Update 2021-10-26-accelerating-pytorch-with-cuda-graphs.md
1 parent b058daf commit dbc4b0f

1 file changed

_posts/2021-10-26-accelerating-pytorch-with-cuda-graphs.md

Lines changed: 7 additions & 53 deletions
@@ -64,6 +64,10 @@ You should try CUDA graphs if all or part of your network is graph-safe (usually
 
 PyTorch exposes graphs via a raw [`torch.cuda.CUDAGraph`](https://pytorch.org/docs/master/generated/torch.cuda.graph.html#torch.cuda.graph) class and two convenience wrappers, [`torch.cuda.graph`](https://pytorch.org/docs/master/generated/torch.cuda.graph.html#torch.cuda.graph) and [`torch.cuda.make_graphed_callables`](https://pytorch.org/docs/master/generated/torch.cuda.make_graphed_callables.html#torch.cuda.make_graphed_callables).
 
+[`torch.cuda.graph`](https://pytorch.org/docs/master/generated/torch.cuda.graph.html#torch.cuda.graph) is a simple, versatile context manager that captures CUDA work in its context. Before capture, warm up the workload to be captured by running a few eager iterations. Warmup must occur on a side stream. Because the graph reads from and writes to the same memory addresses in every replay, you must maintain long-lived references to tensors that hold input and output data during capture. To run the graph on new input data, copy new data to the capture’s input tensor(s), replay the graph, then read the new output from the capture’s output tensor(s).
+
+If the entire network is capture safe, one can capture and replay the whole network as in the following example.
+
 ```python
 N, D_in, H, D_out = 640, 4096, 2048, 1024
 model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
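
The paragraph added in this hunk describes a concrete workflow: warm up on a side stream, capture once, then feed each new batch by copying it into the capture's input tensors and calling `replay()`. A condensed sketch of that whole-network pattern, using the same tensor names as the post's example (dropout layers and some comments omitted for brevity; this is a sketch of the pattern, not the post's full listing):

```python
import torch

N, D_in, H, D_out = 640, 4096, 2048, 1024
model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
                            torch.nn.Linear(H, D_out)).cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Long-lived placeholders: every replay reads and writes these same addresses.
static_input = torch.randn(N, D_in, device='cuda')
static_target = torch.randn(N, D_out, device='cuda')

# Warmup must run on a side stream.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        loss_fn(model(static_input), static_target).backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

# Capture one full training iteration into a graph.
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_y_pred = model(static_input)
    static_loss = loss_fn(static_y_pred, static_target)
    static_loss.backward()
    optimizer.step()

# Training on real data: copy in, replay, read results out.
real_inputs = [torch.rand_like(static_input) for _ in range(10)]
real_targets = [torch.rand_like(static_target) for _ in range(10)]
for data, target in zip(real_inputs, real_targets):
    static_input.copy_(data)
    static_target.copy_(target)
    g.replay()  # forward, backward, and optimizer step all replay here
    # static_y_pred, static_loss, and .grad now hold this iteration's values
```

Because the captured backward refills the static `.grad` tensors in place, `optimizer.zero_grad()` is not needed between replays.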
@@ -118,56 +122,6 @@ for data, target in zip(real_inputs, real_targets):
     # attributes hold values from computing on this iteration's data.
 ```
 
-[`torch.cuda.graph`](https://pytorch.org/docs/master/generated/torch.cuda.graph.html#torch.cuda.graph) is a simple, versatile context manager that captures CUDA work in its context. Before capture, warm up the workload to be captured by running a few eager iterations. Warmup must occur on a side stream. Because the graph reads from and writes to the same memory addresses in every replay, you must maintain long-lived references to tensors that hold input and output data during capture. To run the graph on new input data, copy new data to the capture’s input tensor(s), replay the graph, then read the new output from the capture’s output tensor(s).
-
-If the entire network is capture safe, one can capture and replay the whole network as in the following example.
-
-```python
-N, D_in, H, D_out = 640, 4096, 2048, 1024
-model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
-                            torch.nn.Dropout(p=0.2),
-                            torch.nn.Linear(H, D_out),
-                            torch.nn.Dropout(p=0.1)).cuda()
-loss_fn = torch.nn.MSELoss()
-optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
-
-# Placeholders used for capture
-static_input = torch.randn(N, D_in, device='cuda')
-static_target = torch.randn(N, D_out, device='cuda')
-
-# warmup
-# Uses static_input and static_target here for convenience,
-# but in a real setting, because the warmup includes optimizer.step()
-# you must use a few batches of real data.
-s = torch.cuda.Stream()
-s.wait_stream(torch.cuda.current_stream())
-with torch.cuda.stream(s):
-    for i in range(3):
-        optimizer.zero_grad(set_to_none=True)
-        y_pred = model(static_input)
-        loss = loss_fn(y_pred, static_target)
-        loss.backward()
-        optimizer.step()
-torch.cuda.current_stream().wait_stream(s)
-
-# capture
-g = torch.cuda.CUDAGraph()
-# Sets grads to None before capture, so backward() will create
-# .grad attributes with allocations from the graph's private pool
-optimizer.zero_grad(set_to_none=True)
-with torch.cuda.graph(g):
-    static_y_pred = model(static_input)
-    # Fills the graph's input memory with new data to compute on
-    static_input.copy_(data)
-    static_target.copy_(target)
-    # replay() includes forward, backward, and step.
-    # You don't even need to call optimizer.zero_grad() between iterations
-    # because the captured backward refills static .grad tensors in place.
-    g.replay()
-    # Params have been updated. static_y_pred, static_loss, and .grad
-    # attributes hold values from computing on this iteration's data.
-```
-
 If some of your network is unsafe to capture (e.g., due to dynamic control flow, dynamic shapes, CPU syncs, or essential CPU-side logic), you can run the unsafe part(s) eagerly and use [`torch.cuda.make_graphed_callables`](https://pytorch.org/docs/master/generated/torch.cuda.make_graphed_callables.html#torch.cuda.make_graphed_callables) to graph only the capture-safe part(s). This is demonstrated next.
 
 ```python
@@ -201,8 +155,8 @@ module2 = torch.nn.Linear(H, D_out).cuda()
 module3 = torch.nn.Linear(H, D_out).cuda()
 
 loss_fn = torch.nn.MSELoss()
-optimizer = torch.optim.SGD(chain(module1.parameters() +
-                                  module2.parameters() +
+optimizer = torch.optim.SGD(chain(module1.parameters(),
+                                  module2.parameters(),
                                   module3.parameters()),
                             lr=0.1)
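
This hunk fixes a real bug in the example rather than style: `Module.parameters()` returns a generator, and generators do not support `+`, so the old `chain(module1.parameters() + module2.parameters() + ...)` raises a `TypeError` before `chain()` ever runs. `itertools.chain` takes each iterable as a separate argument. A quick illustration with toy layer sizes (the sizes are arbitrary, chosen only to keep the snippet small):

```python
from itertools import chain

import torch

module1 = torch.nn.Linear(8, 4)
module2 = torch.nn.Linear(4, 2)

# Broken: generators cannot be concatenated with `+`.
try:
    module1.parameters() + module2.parameters()
except TypeError as e:
    print("TypeError:", e)

# Correct: pass each parameter iterator to chain() as its own argument,
# so a single optimizer updates both modules.
optimizer = torch.optim.SGD(chain(module1.parameters(),
                                  module2.parameters()),
                            lr=0.1)
print(len(optimizer.param_groups[0]['params']))  # 4: two weights + two biases
```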

@@ -229,7 +183,7 @@ for data, target in zip(real_inputs, real_targets):
     else:
         tmp = module3(tmp) # forward ops run as a graph
 
-    loss = loss_fn(tmp, y)
+    loss = loss_fn(tmp, target)
     # module2's or module3's (whichever was chosen) backward ops,
     # as well as module1's backward ops, run as graphs
     loss.backward()
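
The last three hunks only show fragments of the partial-capture example that the earlier sentence introduces ("This is demonstrated next"): the data-dependent branch runs eagerly, while the three `Linear` modules are graphed with [`torch.cuda.make_graphed_callables`](https://pytorch.org/docs/master/generated/torch.cuda.make_graphed_callables.html#torch.cuda.make_graphed_callables). A sketch of how those fragments fit together, with both fixes above applied; the sample-input names `x` and `h` are illustrative choices rather than quotes from the post:

```python
from itertools import chain

import torch

N, D_in, H, D_out = 640, 4096, 2048, 1024

module1 = torch.nn.Linear(D_in, H).cuda()
module2 = torch.nn.Linear(H, D_out).cuda()
module3 = torch.nn.Linear(H, D_out).cuda()

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(chain(module1.parameters(),
                                  module2.parameters(),
                                  module3.parameters()),
                            lr=0.1)

# Sample inputs used only to capture each callable's graph.
# Their requires_grad state must match what the callable sees in training:
# module1 receives raw data, module2/module3 receive module1's output.
x = torch.randn(N, D_in, device='cuda')
h = torch.randn(N, H, device='cuda', requires_grad=True)

module1 = torch.cuda.make_graphed_callables(module1, (x,))
module2 = torch.cuda.make_graphed_callables(module2, (h,))
module3 = torch.cuda.make_graphed_callables(module3, (h,))

real_inputs = [torch.rand_like(x) for _ in range(10)]
real_targets = [torch.randn(N, D_out, device='cuda') for _ in range(10)]

for data, target in zip(real_inputs, real_targets):
    optimizer.zero_grad(set_to_none=True)

    tmp = module1(data)           # forward ops run as a graph

    if tmp.sum().item() > 0:      # data-dependent branch stays eager on the CPU
        tmp = module2(tmp)        # forward ops run as a graph
    else:
        tmp = module3(tmp)        # forward ops run as a graph

    loss = loss_fn(tmp, target)
    # Backward ops for module1 and whichever of module2/module3 ran
    # also replay as graphs.
    loss.backward()
    optimizer.step()
```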
