torch.compile tutorial optimizer update #2161

Merged: 5 commits, Dec 21, 2022
16 changes: 4 additions & 12 deletions intermediate_source/torch_compile_tutorial.rst
@@ -172,7 +172,7 @@ see a significant improvement compared to eager.

And indeed, we can see that running our model with ``torch.compile``
results in a significant speedup. On an NVIDIA A100 GPU, we observe a
-2.2x speedup. Speedup mainly comes from reducing Python overhead and
+2.3x speedup. Speedup mainly comes from reducing Python overhead and
GPU read/writes, and so the observed speedup may vary on factors such as model
architecture and batch size. For example, if a model's architecture is simple
and the amount of data is large, then the bottleneck would be
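
As context for the updated speedup figure above, here is a minimal sketch of the kind of eager-vs-compiled inference timing the tutorial performs. The model, input shapes, and timed helper are illustrative assumptions, not the tutorial's exact definitions:

import time
import torch
import torchvision.models as models

# Illustrative model and input; the tutorial's actual choices may differ.
model = models.resnet18().cuda().eval()
inp = torch.randn(16, 3, 224, 224).cuda()

def timed(fn):
    # Wall-clock timing with explicit synchronization so queued GPU work is counted.
    torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn()
    torch.cuda.synchronize()
    return result, time.perf_counter() - start

def evaluate(mod, x):
    with torch.no_grad():
        return mod(x)

evaluate_opt = torch.compile(evaluate, mode="reduce-overhead")

_, eager_time = timed(lambda: evaluate(model, inp))
_, warmup_time = timed(lambda: evaluate_opt(model, inp))    # first call compiles
_, compiled_time = timed(lambda: evaluate_opt(model, inp))  # steady-state
print(f"eager: {eager_time:.4f}s, compiled (after warmup): {compiled_time:.4f}s")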
@@ -196,16 +196,16 @@ Now, let's consider comparing training.
opt = torch.optim.Adam(model.parameters())

def train(mod, data):
+    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
+    opt.step()

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
-    opt.zero_grad(True)
    _, eager_time = timed(lambda: train(model, inp))
-    opt.step()
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)
@@ -217,9 +217,7 @@ Now, let's consider comparing training.
compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
-    opt.zero_grad(True)
    _, compile_time = timed(lambda: train_opt(model, inp))
-    opt.step()
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)
@@ -233,13 +231,7 @@ Now, let's consider comparing training.
Again, we can see that ``torch.compile`` takes longer in the first
iteration, as it must compile the model, but afterward, we see
significant speedups compared to eager. On an NVIDIA A100 GPU, we
-observe a 1.8x speedup.
-
-One thing to note is that, as of now, we cannot place optimizer code --
-``opt.zero_grad`` and ``opt.step`` -- inside of an optimized function.
-The rest of the training loop -- the forward pass and the backward pass --
-can be optimized. We are currently working on enabling optimizers to be
-compatible with ``torch.compile``.
+observe a 2.2x speedup.

Comparison to TorchScript and FX Tracing
-----------------------------------------
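
For reference, a self-contained sketch of the training loop as it reads after this change, with the optimizer calls inside the compiled function, is shown below. The model, data shapes, and timed helper are stand-ins assumed for illustration; only the structure of train mirrors the diff:

import time
import torch
import torchvision.models as models

# Illustrative stand-ins for the tutorial's model and helpers.
model = models.resnet18().cuda()
opt = torch.optim.Adam(model.parameters())
N_ITERS = 10

def generate_data(b):
    # Random images and labels on the GPU; shapes are assumptions, not from the diff.
    return torch.randn(b, 3, 224, 224).cuda(), torch.randint(1000, (b,)).cuda()

def timed(fn):
    # Synchronize before and after so the measurement includes all GPU work.
    torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn()
    torch.cuda.synchronize()
    return result, time.perf_counter() - start

def train(mod, data):
    # After this change the optimizer calls live inside the function that gets compiled.
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

train_opt = torch.compile(train)

for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    print(f"compile train time {i}: {compile_time:.4f}s")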
16 changes: 4 additions & 12 deletions intermediate_source/torch_compile_tutorial_.py
@@ -175,7 +175,7 @@ def evaluate(mod, inp):
######################################################################
# And indeed, we can see that running our model with ``torch.compile``
# results in a significant speedup. On an NVIDIA A100 GPU, we observe a
-# 2.2x speedup. Speedup mainly comes from reducing Python overhead and
+# 2.3x speedup. Speedup mainly comes from reducing Python overhead and
# GPU read/writes, and so the observed speedup may vary on factors such as model
# architecture and batch size. For example, if a model's architecture is simple
# and the amount of data is large, then the bottleneck would be
@@ -197,16 +197,16 @@ def evaluate(mod, inp):
opt = torch.optim.Adam(model.parameters())

def train(mod, data):
+    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
+    opt.step()

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
-    opt.zero_grad(True)
    _, eager_time = timed(lambda: train(model, inp))
-    opt.step()
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)
@@ -218,9 +218,7 @@ def train(mod, data):
compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
-    opt.zero_grad(True)
    _, compile_time = timed(lambda: train_opt(model, inp))
-    opt.step()
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)
@@ -235,13 +233,7 @@ def train(mod, data):
# Again, we can see that ``torch.compile`` takes longer in the first
# iteration, as it must compile the model, but afterward, we see
# significant speedups compared to eager. On an NVIDIA A100 GPU, we
-# observe a 1.8x speedup.
-#
-# One thing to note is that, as of now, we cannot place optimizer code --
-# ``opt.zero_grad`` and ``opt.step`` -- inside of an optimized function.
-# The rest of the training loop -- the forward pass and the backward pass --
-# can be optimized. We are currently working on enabling optimizers to be
-# compatible with ``torch.compile``.
+# observe a 2.2x speedup.

######################################################################
# Comparison to TorchScript and FX Tracing