# -*- coding: utf-8 -*-

"""
Dynamic Compilation Control with ``torch.compiler.set_stance``
=========================================================================
**Author:** `William Wen <https://github.com/williamwen42>`_
"""

######################################################################
# ``torch.compiler.set_stance`` is a ``torch.compiler`` API that
# enables you to change the behavior of ``torch.compile`` across different
# calls to your model without having to reapply ``torch.compile`` to it.
#
# This recipe provides some examples of how to use ``torch.compiler.set_stance``.
#
#
# .. contents::
#    :local:
#
# Prerequisites
# ---------------
#
# - ``torch >= 2.6``

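######################################################################
# As a quick sanity check (this guard is our own convention, not part of the
# ``torch.compiler`` API), you can verify the installed ``torch`` version
# before running the rest of the recipe:

import torch

if tuple(int(v) for v in torch.__version__.split(".")[:2]) < (2, 6):
    raise RuntimeError("This recipe requires torch >= 2.6")
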
######################################################################
# Description
# -----------
# ``torch.compiler.set_stance`` can be used as a decorator, context manager, or raw function
# to change the behavior of ``torch.compile`` across different calls to your model.
#
# In the example below, the ``"force_eager"`` stance ignores all ``torch.compile`` directives.

import torch


@torch.compile
def foo(x):
    if torch.compiler.is_compiling():
        # torch.compile is active
        return x + 1
    else:
        # torch.compile is not active
        return x - 1


inp = torch.zeros(3)

print(foo(inp))  # compiled, prints 1

######################################################################
# Sample decorator usage


@torch.compiler.set_stance("force_eager")
def bar(x):
    # force disable the compiler
    return foo(x)


print(bar(inp))  # not compiled, prints -1

######################################################################
# Sample context manager usage

with torch.compiler.set_stance("force_eager"):
    print(foo(inp))  # not compiled, prints -1

######################################################################
# Sample raw function usage

torch.compiler.set_stance("force_eager")
print(foo(inp))  # not compiled, prints -1
torch.compiler.set_stance("default")

print(foo(inp))  # compiled, prints 1

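######################################################################
# Note that the raw function form changes the stance globally until it is
# explicitly set back. One pattern (a sketch of our own, not required by the
# API) is to wrap the calls in ``try``/``finally`` so the default stance is
# always restored, even if the model call raises:

torch.compiler.set_stance("force_eager")
try:
    print(foo(inp))  # not compiled, prints -1
finally:
    # restore normal compilation behavior regardless of what happened above
    torch.compiler.set_stance("default")

print(foo(inp))  # compiled again, prints 1
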
######################################################################
# The ``torch.compile`` stance can only be changed **outside** of any ``torch.compile`` region. Attempts
# to do otherwise will result in an error.


@torch.compile
def baz(x):
    # error!
    with torch.compiler.set_stance("force_eager"):
        return x + 1


try:
    baz(inp)
except Exception as e:
    print(e)


@torch.compiler.set_stance("force_eager")
def inner(x):
    return x + 1


@torch.compile
def outer(x):
    # error!
    return inner(x)


try:
    outer(inp)
except Exception as e:
    print(e)

######################################################################
# Other stances include:
#
# - ``"default"``: The default stance, used for normal compilation.
# - ``"eager_on_recompile"``: Run code eagerly when a recompile is necessary. If there is cached compiled code valid for the input, it will still be used.
# - ``"fail_on_recompile"``: Raise an error when recompiling a function.
#
# See the ``torch.compiler.set_stance`` `doc page <https://pytorch.org/docs/main/generated/torch.compiler.set_stance.html#torch.compiler.set_stance>`__
# for more stances and options. More stances/options may also be added in the future.

######################################################################
# Examples
# --------

######################################################################
# Preventing recompilation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Some models do not expect any recompilations - for example, you may always have inputs with the same shape.
# Since recompilations may be expensive, we may wish to error out when we attempt to recompile so that we can detect and fix recompilation cases.
# The ``"fail_on_recompile"`` stance can be used for this.


@torch.compile
def my_big_model(x):
    return torch.relu(x)


# first compilation
my_big_model(torch.randn(3))

with torch.compiler.set_stance("fail_on_recompile"):
    my_big_model(torch.randn(3))  # no recompilation - OK
    try:
        my_big_model(torch.randn(4))  # recompilation - error
    except Exception as e:
        print(e)

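######################################################################
# The stance only applies while the context manager is active. Outside of it,
# the same call is free to recompile as usual:

my_big_model(torch.randn(4))  # recompiles normally - no error
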
######################################################################
# If erroring out is too disruptive, we can use ``"eager_on_recompile"`` instead,
# which will cause ``torch.compile`` to fall back to eager rather than raising an error.
# This may be useful if we don't expect recompilations to happen frequently, but
# when one is required, we'd rather pay the cost of running eagerly than the cost of recompiling.


@torch.compile
def my_huge_model(x):
    if torch.compiler.is_compiling():
        return x + 1
    else:
        return x - 1


# first compilation
print(my_huge_model(torch.zeros(3)))  # 1

with torch.compiler.set_stance("eager_on_recompile"):
    print(my_huge_model(torch.zeros(3)))  # 1
    print(my_huge_model(torch.zeros(4)))  # -1
    print(my_huge_model(torch.zeros(3)))  # 1

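######################################################################
# Once the default stance is restored, the call with the new shape triggers a
# normal recompile and runs compiled again:

print(my_huge_model(torch.zeros(4)))  # recompiled, prints 1
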
######################################################################
# Measuring performance gains
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# ``torch.compiler.set_stance`` can be used to compare eager vs. compiled performance
# without having to define a separate eager model.


# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements, so a CUDA-capable device is required.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000


@torch.compile
def my_gigantic_model(x, y):
    x = x @ y
    x = x @ y
    x = x @ y
    return x


inps = torch.randn(5, 5), torch.randn(5, 5)

with torch.compiler.set_stance("force_eager"):
    print("eager:", timed(lambda: my_gigantic_model(*inps))[1])

# warmups
for _ in range(3):
    my_gigantic_model(*inps)

print("compiled:", timed(lambda: my_gigantic_model(*inps))[1])

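######################################################################
# A single timed call can be noisy, so in practice you may want to average over
# several iterations. A minimal sketch reusing the ``timed`` helper above (the
# iteration count of 10 is an arbitrary choice):

n_iters = 10
with torch.compiler.set_stance("force_eager"):
    eager_time = sum(timed(lambda: my_gigantic_model(*inps))[1] for _ in range(n_iters)) / n_iters
compiled_time = sum(timed(lambda: my_gigantic_model(*inps))[1] for _ in range(n_iters)) / n_iters
print(f"avg eager: {eager_time:.6f}s, avg compiled: {compiled_time:.6f}s")
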
######################################################################
# Crashing sooner
# ~~~~~~~~~~~~~~~
#
# Running an eager iteration before a compiled iteration using the ``"force_eager"`` stance
# can help us catch errors unrelated to ``torch.compile`` before attempting a very long compile.


@torch.compile
def my_humongous_model(x):
    # deliberately buggy call: torch.sin takes only one positional argument
    return torch.sin(x, x)


try:
    with torch.compiler.set_stance("force_eager"):
        print(my_humongous_model(torch.randn(3)))
    # this call to the compiled model won't run
    print(my_humongous_model(torch.randn(3)))
except Exception as e:
    print(e)

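######################################################################
# This pattern can be wrapped in a small helper. A sketch (the
# ``eager_then_compiled`` name is our own, not a PyTorch API) that runs one
# eager iteration before the compiled call:


def eager_then_compiled(model, *args):
    # run eagerly first so plain Python or shape errors surface immediately,
    # before paying for a potentially long compilation
    with torch.compiler.set_stance("force_eager"):
        model(*args)
    # the eager run succeeded, so proceed with the compiled call
    return model(*args)


print(eager_then_compiled(my_big_model, torch.randn(3)))
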
########################################
# Conclusion
# --------------
# In this recipe, we have learned how to use the ``torch.compiler.set_stance`` API
# to modify the behavior of ``torch.compile`` across different calls to a model
# without needing to reapply it. The recipe demonstrates using
# ``torch.compiler.set_stance`` as a decorator, context manager, or raw function
# to control compilation stances like ``"force_eager"``, ``"default"``,
# ``"eager_on_recompile"``, and ``"fail_on_recompile"``.
#
# For more information, see: `torch.compiler.set_stance API documentation <https://pytorch.org/docs/main/generated/torch.compiler.set_stance.html#torch.compiler.set_stance>`__.