From 57947f07a75debd8688f0ac2c9325df58a635a2b Mon Sep 17 00:00:00 2001
From: William Wen
Date: Tue, 10 Oct 2023 15:26:16 -0700
Subject: [PATCH 1/2] small fixes to torch.compile tutorial

---
 intermediate_source/torch_compile_tutorial.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/intermediate_source/torch_compile_tutorial.py b/intermediate_source/torch_compile_tutorial.py
index 1093f7577e1..36bc231bccf 100644
--- a/intermediate_source/torch_compile_tutorial.py
+++ b/intermediate_source/torch_compile_tutorial.py
@@ -138,21 +138,18 @@ def init_model():
 # Note that in the call to ``torch.compile``, we have the additional
 # ``mode`` argument, which we will discuss below.
 
-def evaluate(mod, inp):
-    return mod(inp)
-
 model = init_model()
 
 # Reset since we are using a different mode.
 import torch._dynamo
 torch._dynamo.reset()
 
-evaluate_opt = torch.compile(evaluate, mode="reduce-overhead")
+model_opt = torch.compile(model, mode="reduce-overhead")
 
 inp = generate_data(16)[0]
 with torch.no_grad():
-    print("eager:", timed(lambda: evaluate(model, inp))[1])
-    print("compile:", timed(lambda: evaluate_opt(model, inp))[1])
+    print("eager:", timed(lambda: model(inp))[1])
+    print("compile:", timed(lambda: model_opt(inp))[1])
 
 ######################################################################
 # Notice that ``torch.compile`` takes a lot longer to complete
@@ -166,7 +163,7 @@ def evaluate(mod, inp):
 for i in range(N_ITERS):
     inp = generate_data(16)[0]
     with torch.no_grad():
-        _, eager_time = timed(lambda: evaluate(model, inp))
+        _, eager_time = timed(lambda: model(inp))
     eager_times.append(eager_time)
     print(f"eager eval time {i}: {eager_time}")
 
@@ -176,7 +173,7 @@ def evaluate(mod, inp):
 for i in range(N_ITERS):
     inp = generate_data(16)[0]
     with torch.no_grad():
-        _, compile_time = timed(lambda: evaluate_opt(model, inp))
+        _, compile_time = timed(lambda: model_opt(inp))
     compile_times.append(compile_time)
     print(f"compile eval time {i}: {compile_time}")
 print("~" * 10)
@@ -250,6 +247,10 @@ def train(mod, data):
 # Again, we can see that ``torch.compile`` takes longer in the first
 # iteration, as it must compile the model, but in subsequent iterations, we see
 # significant speedups compared to eager.
+#
+# We remark that the speedup numbers presented in this tutorial are for
+# demonstration purposes only. Official speedup values can be seen at the
+# `TorchInductor performance dashboard <https://hud.pytorch.org/benchmark/compilers>`__
 
 ######################################################################
 # Comparison to TorchScript and FX Tracing

From 94fadec7a4c13bf1106ecfef504aaa01d877d881 Mon Sep 17 00:00:00 2001
From: William Wen
Date: Tue, 10 Oct 2023 17:32:13 -0700
Subject: [PATCH 2/2] Update intermediate_source/torch_compile_tutorial.py

---
 intermediate_source/torch_compile_tutorial.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/intermediate_source/torch_compile_tutorial.py b/intermediate_source/torch_compile_tutorial.py
index 36bc231bccf..b6ac9ee3436 100644
--- a/intermediate_source/torch_compile_tutorial.py
+++ b/intermediate_source/torch_compile_tutorial.py
@@ -250,7 +250,7 @@ def train(mod, data):
 #
 # We remark that the speedup numbers presented in this tutorial are for
 # demonstration purposes only. Official speedup values can be seen at the
-# `TorchInductor performance dashboard <https://hud.pytorch.org/benchmark/compilers>`__
+# `TorchInductor performance dashboard <https://hud.pytorch.org/benchmark/compilers>`__.
 
 ######################################################################
 # Comparison to TorchScript and FX Tracing
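
For reference, below is a minimal, self-contained sketch of the pattern these patches converge on: passing the nn.Module directly to torch.compile(model, mode="reduce-overhead") instead of wrapping calls in an evaluate(mod, inp) function. The tutorial defines its own init_model, generate_data, and timed helpers earlier in the file; the versions here are simplified, CPU-only stand-ins for illustration, not the tutorial's actual implementations.

# Illustrative sketch only -- the helper definitions below are simplified
# stand-ins, not the tutorial's actual (GPU-timed) implementations.
import time
import torch

def init_model():
    # Stand-in for the tutorial's model.
    return torch.nn.Sequential(
        torch.nn.Linear(64, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 8),
    )

def generate_data(b):
    # Stand-in: returns (inputs, labels) for a batch of size b.
    return torch.randn(b, 64), torch.randint(0, 8, (b,))

def timed(fn):
    # Wall-clock timing stand-in; returns (result, elapsed seconds).
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start

model = init_model()
# Compile the module directly; no wrapper function is needed, and the
# compiled module is a drop-in replacement for the original.
model_opt = torch.compile(model, mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    # The first compiled call pays the compilation cost; the speedup
    # appears on subsequent calls.
    print("compile:", timed(lambda: model_opt(inp))[1])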