diff --git a/intermediate_source/torch_compile_tutorial.py b/intermediate_source/torch_compile_tutorial.py index 739eeee422d..1093f7577e1 100644 --- a/intermediate_source/torch_compile_tutorial.py +++ b/intermediate_source/torch_compile_tutorial.py @@ -139,8 +139,7 @@ def init_model(): # ``mode`` argument, which we will discuss below. def evaluate(mod, inp): - with torch.no_grad(): - return mod(inp) + return mod(inp) model = init_model() @@ -151,8 +150,9 @@ def evaluate(mod, inp): evaluate_opt = torch.compile(evaluate, mode="reduce-overhead") inp = generate_data(16)[0] -print("eager:", timed(lambda: evaluate(model, inp))[1]) -print("compile:", timed(lambda: evaluate_opt(model, inp))[1]) +with torch.no_grad(): + print("eager:", timed(lambda: evaluate(model, inp))[1]) + print("compile:", timed(lambda: evaluate_opt(model, inp))[1]) ###################################################################### # Notice that ``torch.compile`` takes a lot longer to complete @@ -165,7 +165,8 @@ def evaluate(mod, inp): eager_times = [] for i in range(N_ITERS): inp = generate_data(16)[0] - _, eager_time = timed(lambda: evaluate(model, inp)) + with torch.no_grad(): + _, eager_time = timed(lambda: evaluate(model, inp)) eager_times.append(eager_time) print(f"eager eval time {i}: {eager_time}") @@ -174,7 +175,8 @@ def evaluate(mod, inp): compile_times = [] for i in range(N_ITERS): inp = generate_data(16)[0] - _, compile_time = timed(lambda: evaluate_opt(model, inp)) + with torch.no_grad(): + _, compile_time = timed(lambda: evaluate_opt(model, inp)) compile_times.append(compile_time) print(f"compile eval time {i}: {compile_time}") print("~" * 10) @@ -183,6 +185,7 @@ def evaluate(mod, inp): eager_med = np.median(eager_times) compile_med = np.median(compile_times) speedup = eager_med / compile_med +assert(speedup > 1) print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x") print("~" * 10) @@ -239,6 +242,7 @@ def train(mod, data): eager_med = np.median(eager_times) compile_med = np.median(compile_times) speedup = eager_med / compile_med +assert(speedup > 1) print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x") print("~" * 10)