Commit 733b1ec

Add torch.compiler.set_stance tutorial (#3260)

* add torch.compiler.set_stance tutorial

Co-authored-by: William Wen <williamwen@meta.com>

1 parent db9a867 commit 733b1ec

File tree

2 files changed: +251 -0 lines changed


recipes_source/recipes_index.rst

Lines changed: 7 additions & 0 deletions
@@ -122,6 +122,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/torch_compile_backend_ipex.html
    :tags: Basics
 
+.. customcarditem::
+   :header: Dynamic Compilation Control with ``torch.compiler.set_stance``
+   :card_description: Learn how to use ``torch.compiler.set_stance`` to control ``torch.compile`` behavior across calls.
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/torch_compiler_set_stance_tutorial.html
+   :tags: Compiler
+
 .. customcarditem::
    :header: Reasoning about Shapes in PyTorch
    :card_description: Learn how to use the meta device to reason about shapes in your model.
recipes_source/torch_compiler_set_stance_tutorial.py

Lines changed: 244 additions & 0 deletions

@@ -0,0 +1,244 @@
# -*- coding: utf-8 -*-

"""
Dynamic Compilation Control with ``torch.compiler.set_stance``
=========================================================================
**Author:** `William Wen <https://github.com/williamwen42>`_
"""

######################################################################
# ``torch.compiler.set_stance`` is a ``torch.compiler`` API that
# enables you to change the behavior of ``torch.compile`` across different
# calls to your model without having to reapply ``torch.compile``.
#
# This recipe provides some examples of how to use ``torch.compiler.set_stance``.
#
# .. contents::
#    :local:
#
# Prerequisites
# ---------------
#
# - ``torch >= 2.6``

######################################################################
# Description
# -----------
# ``torch.compiler.set_stance`` can be used as a decorator, context manager, or raw function
# to change the behavior of ``torch.compile`` across different calls to your model.
#
# In the example below, the ``"force_eager"`` stance ignores all ``torch.compile`` directives.

import torch


@torch.compile
def foo(x):
    if torch.compiler.is_compiling():
        # torch.compile is active
        return x + 1
    else:
        # torch.compile is not active
        return x - 1


inp = torch.zeros(3)

print(foo(inp))  # compiled, prints 1

######################################################################
# Sample decorator usage


@torch.compiler.set_stance("force_eager")
def bar(x):
    # force disable the compiler
    return foo(x)


print(bar(inp))  # not compiled, prints -1

######################################################################
# Sample context manager usage

with torch.compiler.set_stance("force_eager"):
    print(foo(inp))  # not compiled, prints -1

######################################################################
# Sample raw function usage

torch.compiler.set_stance("force_eager")
print(foo(inp))  # not compiled, prints -1
torch.compiler.set_stance("default")

print(foo(inp))  # compiled, prints 1

######################################################################
# The ``torch.compile`` stance can only be changed **outside** of any ``torch.compile``
# region. Attempting to change it inside one results in an error.


@torch.compile
def baz(x):
    # error!
    with torch.compiler.set_stance("force_eager"):
        return x + 1


try:
    baz(inp)
except Exception as e:
    print(e)


@torch.compiler.set_stance("force_eager")
def inner(x):
    return x + 1


@torch.compile
def outer(x):
    # error!
    return inner(x)


try:
    outer(inp)
except Exception as e:
    print(e)

######################################################################
# Other stances include:
#
# - ``"default"``: The default stance, used for normal compilation.
# - ``"eager_on_recompile"``: Run code eagerly when a recompile is necessary. If there is cached compiled code valid for the input, it will still be used.
# - ``"fail_on_recompile"``: Raise an error when recompiling a function.
#
# See the ``torch.compiler.set_stance`` `doc page <https://pytorch.org/docs/main/generated/torch.compiler.set_stance.html#torch.compiler.set_stance>`__
# for more stances and options. More stances and options may be added in the future.
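
######################################################################
# As a hedged illustration of those extra options: the linked doc page also
# documents keyword arguments such as ``force_backend``, which forces
# ``torch.compile`` to use a particular backend while the stance is active.
# This sketch assumes ``force_backend`` is available in your installed
# version (check the doc page above first):

with torch.compiler.set_stance("default", force_backend="eager"):
    # assumption: the "eager" backend runs the Dynamo-traced graph without
    # further backend compilation; foo is still traced, so it should print 1
    print(foo(inp))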

######################################################################
# Examples
# --------

######################################################################
# Preventing recompilation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Some models do not expect any recompilations - for example, you may always have inputs with the same shape.
# Since recompilation can be expensive, we may wish to raise an error when a recompile is attempted,
# so that we can detect and fix the cases that cause it.
# The ``"fail_on_recompile"`` stance can be used for this.


@torch.compile
def my_big_model(x):
    return torch.relu(x)


# first compilation
my_big_model(torch.randn(3))

with torch.compiler.set_stance("fail_on_recompile"):
    my_big_model(torch.randn(3))  # no recompilation - OK
    try:
        my_big_model(torch.randn(4))  # recompilation - error
    except Exception as e:
        print(e)
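
######################################################################
# Once ``"fail_on_recompile"`` surfaces a recompilation, one possible fix is
# sketched below (an addition to the recipe, assuming the recompile is caused
# by a varying input shape, as above): compile with ``dynamic=True`` so the
# shape is traced symbolically instead of being specialized.

my_flexible_model = torch.compile(torch.relu, dynamic=True)
my_flexible_model(torch.randn(3))  # first compilation, with a dynamic shape

with torch.compiler.set_stance("fail_on_recompile"):
    my_flexible_model(torch.randn(4))  # different shape, but no recompile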

######################################################################
# If erroring out is too disruptive, we can use ``"eager_on_recompile"`` instead,
# which causes ``torch.compile`` to fall back to eager execution rather than raising an error.
# This may be useful if we don't expect recompilations to happen frequently, but
# when one is required, we'd rather pay the cost of running eagerly than the cost of recompiling.


@torch.compile
def my_huge_model(x):
    if torch.compiler.is_compiling():
        return x + 1
    else:
        return x - 1


# first compilation
print(my_huge_model(torch.zeros(3)))  # 1

with torch.compiler.set_stance("eager_on_recompile"):
    print(my_huge_model(torch.zeros(3)))  # 1
    print(my_huge_model(torch.zeros(4)))  # -1
    print(my_huge_model(torch.zeros(3)))  # 1


######################################################################
# Measuring performance gains
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# ``torch.compiler.set_stance`` can be used to compare eager vs. compiled performance
# without having to define a separate eager model.


# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000
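
######################################################################
# Note that ``timed`` assumes a CUDA device is available. On a CPU-only
# machine, a minimal alternative sketch (a hypothetical helper, not part of
# the original recipe) can use the standard-library ``time.perf_counter``:

import time


def timed_cpu(fn):
    # wall-clock timing; less precise for GPU work, but fine for this demo
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start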


@torch.compile
def my_gigantic_model(x, y):
    x = x @ y
    x = x @ y
    x = x @ y
    return x


inps = torch.randn(5, 5), torch.randn(5, 5)

with torch.compiler.set_stance("force_eager"):
    print("eager:", timed(lambda: my_gigantic_model(*inps))[1])

# warmups
for _ in range(3):
    my_gigantic_model(*inps)

print("compiled:", timed(lambda: my_gigantic_model(*inps))[1])


######################################################################
# Crashing sooner
# ~~~~~~~~~~~~~~~
#
# Running an eager iteration before a compiled iteration, using the ``"force_eager"`` stance,
# can help us catch errors unrelated to ``torch.compile`` before attempting a very long compile.


@torch.compile
def my_humongous_model(x):
    # deliberate bug: torch.sin takes only one positional argument
    return torch.sin(x, x)


try:
    with torch.compiler.set_stance("force_eager"):
        print(my_humongous_model(torch.randn(3)))
    # this call to the compiled model won't run
    print(my_humongous_model(torch.randn(3)))
except Exception as e:
    print(e)

######################################################################
# Conclusion
# --------------
# In this recipe, we have learned how to use the ``torch.compiler.set_stance`` API
# to modify the behavior of ``torch.compile`` across different calls to a model
# without needing to reapply it. The recipe demonstrates using
# ``torch.compiler.set_stance`` as a decorator, context manager, or raw function
# to control compilation stances like ``"force_eager"``, ``"default"``,
# ``"eager_on_recompile"``, and ``"fail_on_recompile"``.
#
# For more information, see: `torch.compiler.set_stance API documentation <https://pytorch.org/docs/main/generated/torch.compiler.set_stance.html#torch.compiler.set_stance>`__.
