_posts/2021-10-26-accelerating-pytorch-with-cuda-graphs.md
+7 −53 (7 additions, 53 deletions)
@@ -64,6 +64,10 @@ You should try CUDA graphs if all or part of your network is graph-safe (usually
PyTorch exposes graphs via a raw [`torch.cuda.CUDAGraph`](https://pytorch.org/docs/master/generated/torch.cuda.graph.html#torch.cuda.graph) class and two convenience wrappers, [`torch.cuda.graph`](https://pytorch.org/docs/master/generated/torch.cuda.graph.html#torch.cuda.graph) and [`torch.cuda.make_graphed_callables`](https://pytorch.org/docs/master/generated/torch.cuda.make_graphed_callables.html#torch.cuda.make_graphed_callables).

+[`torch.cuda.graph`](https://pytorch.org/docs/master/generated/torch.cuda.graph.html#torch.cuda.graph) is a simple, versatile context manager that captures CUDA work in its context. Before capture, warm up the workload to be captured by running a few eager iterations. Warmup must occur on a side stream. Because the graph reads from and writes to the same memory addresses in every replay, you must maintain long-lived references to tensors that hold input and output data during capture. To run the graph on new input data, copy new data to the capture’s input tensor(s), replay the graph, then read the new output from the capture’s output tensor(s).
+
+If the entire network is capture safe, one can capture and replay the whole network as in the following example.
+
```python
N, D_in, H, D_out = 640, 4096, 2048, 1024
model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
@@ -118,56 +122,6 @@ for data, target in zip(real_inputs, real_targets):
    # attributes hold values from computing on this iteration's data.
```
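The pattern described above, copy new data into the static input tensors, replay, then read the static outputs, is not limited to full training iterations. Below is a minimal forward-only sketch of the same workflow; the model, shapes, and batch loop are illustrative assumptions rather than code from the post.

```python
import torch

# Minimal sketch of the copy -> replay -> read pattern for a forward pass only.
# The module and shapes below are placeholders, not from the original post.
device = "cuda"
model = torch.nn.Linear(4096, 1024).to(device)

# Static tensor: every replay reads from (and the graph writes to) the same addresses.
static_input = torch.randn(8, 4096, device=device)

# Warmup must run on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture the forward pass into a graph.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

# Run on new data: copy into the static input, replay, read the static output.
for _ in range(5):
    new_batch = torch.randn(8, 4096, device=device)
    static_input.copy_(new_batch)
    g.replay()
    result = static_output.clone()  # clone only if results must outlive the next replay
```

The only difference from the training example is what gets captured: here the graph contains just the forward pass, so there is no optimizer or gradient state that needs to stay static.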
-[`torch.cuda.graph`](https://pytorch.org/docs/master/generated/torch.cuda.graph.html#torch.cuda.graph) is a simple, versatile context manager that captures CUDA work in its context. Before capture, warm up the workload to be captured by running a few eager iterations. Warmup must occur on a side stream. Because the graph reads from and writes to the same memory addresses in every replay, you must maintain long-lived references to tensors that hold input and output data during capture. To run the graph on new input data, copy new data to the capture’s input tensor(s), replay the graph, then read the new output from the capture’s output tensor(s).
-
-If the entire network is capture safe, one can capture and replay the whole network as in the following example.
-
-```python
-N, D_in, H, D_out = 640, 4096, 2048, 1024
-model = torch.nn.Sequential(torch.nn.Linear(D_in, H),
⋮ (additional removed lines not shown in this view)
-# Uses static_input and static_target here for convenience,
-# but in a real setting, because the warmup includes optimizer.step()
-# you must use a few batches of real data.
-s = torch.cuda.Stream()
-s.wait_stream(torch.cuda.current_stream())
-with torch.cuda.stream(s):
-    for i in range(3):
-        optimizer.zero_grad(set_to_none=True)
-        y_pred = model(static_input)
-        loss = loss_fn(y_pred, static_target)
-        loss.backward()
-        optimizer.step()
-torch.cuda.current_stream().wait_stream(s)
-
-# capture
-g = torch.cuda.CUDAGraph()
-# Sets grads to None before capture, so backward() will create
-# .grad attributes with allocations from the graph's private pool
-optimizer.zero_grad(set_to_none=True)
-with torch.cuda.graph(g):
-    static_y_pred = model(static_input)
-# Fills the graph's input memory with new data to compute on
-static_input.copy_(data)
-static_target.copy_(target)
-# replay() includes forward, backward, and step.
-# You don't even need to call optimizer.zero_grad() between iterations
-# because the captured backward refills static .grad tensors in place.
-g.replay()
-# Params have been updated. static_y_pred, static_loss, and .grad
-# attributes hold values from computing on this iteration's data.
-```
If some of your network is unsafe to capture (e.g., due to dynamic control flow, dynamic shapes, CPU syncs, or essential CPU-side logic), you can run the unsafe part(s) eagerly and use [`torch.cuda.make_graphed_callables`](https://pytorch.org/docs/master/generated/torch.cuda.make_graphed_callables.html#torch.cuda.make_graphed_callables) to graph only the capture-safe part(s). This is demonstrated next.
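As a rough sketch of what that looks like (the demonstration itself follows in the post), the capture-safe module below is graphed with `torch.cuda.make_graphed_callables` while the rest of the step stays eager. The module names, shapes, and training step are illustrative assumptions, not the post's actual example.

```python
import torch

# Hypothetical partial-network setup: `graphable` is capture-safe, while
# `dynamic_part` stands in for a piece that must stay eager (e.g. data-dependent logic).
N, D_in, H, D_out = 64, 512, 256, 128
graphable = torch.nn.Linear(D_in, H).cuda()
dynamic_part = torch.nn.Linear(H, D_out).cuda()

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(
    list(graphable.parameters()) + list(dynamic_part.parameters()), lr=0.1
)

# Sample args must match the real inputs' shapes, dtypes, and requires_grad state.
sample_input = torch.randn(N, D_in, device="cuda")

# Graphs the module's forward AND backward; warmup iterations are handled internally.
graphable = torch.cuda.make_graphed_callables(graphable, (sample_input,))

# The graphed callable interoperates with eager modules and autograd as usual.
x = torch.randn(N, D_in, device="cuda")
target = torch.randn(N, D_out, device="cuda")

optimizer.zero_grad(set_to_none=True)
y = dynamic_part(graphable(x))
loss = loss_fn(y, target)
loss.backward()
optimizer.step()
```

Only the graphed module runs as a CUDA graph; the eager part keeps its dynamic behavior, and gradients flow through both parts in the ordinary way.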