From 9ff6d8947126a11bed1d9b9c870e1ae8a3397894 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Tue, 21 Jan 2025 16:56:46 -0800 Subject: [PATCH 1/5] init test branch --- intermediate_source/torch_export_tutorial.py | 182 +++++++++++++++++++ 1 file changed, 182 insertions(+) diff --git a/intermediate_source/torch_export_tutorial.py b/intermediate_source/torch_export_tutorial.py index 9acacf53629..26de0278851 100644 --- a/intermediate_source/torch_export_tutorial.py +++ b/intermediate_source/torch_export_tutorial.py @@ -629,6 +629,188 @@ def forward(self, x, y): "bool_val": None, } +###################################################################### +# Data-dependent errors +# --------------------- +# +# While trying to export models, you have may have encountered errors like ``Could not guard on data-dependent expression`` or ``Could not extract specialized integer from data-dependent expression``. +# Obscure as they may seem, the reasoning behind their existence, and their resolution, is actually quite straightforward. +# +# These errors exist because ``torch.export()`` compiles programs using ``FakeTensors``, which symbolically represent their real tensor counterparts (e.g. they may have the same or equivalent symbolic properties +# - sizes, strides, dtypes, etc.), but diverge in one major respect: they do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that the compiler may +# struggle with user code that relies on data values. In short, if the compiler requires a concrete, specialized value that is dependent on tensor data in order to proceed, it will error, complaining that +# FakeTensor tracing isn't providing the information required. +# +# Let's talk about where data-dependent values appear in programs. Common sources are calls like ``item()``, ``tolist()``, or ``torch.unbind()`` that extract scalar values from tensors. +# How are these values represented in the exported program? In the ``Constraints/Dynamic Shapes`` section, we talked about allocating symbols to represent dynamic input dimensions, and the same happens here - +# we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked SymInts", in contrast to the "backed" symbols/SymInts +# allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence, or absence, of a "hint" for the symbol: a concrete value backing the symbol, that can inform the compiler how to proceed. +# +# For dynamic input shapes (backed SymInts), these hints are taken from the shapes of the sample inputs provided, which explains why sample input shapes direct the compiler in control-flow branching. +# On the other hand, data-dependent values are derived from FakeTensors during tracing, and by default lack hints to inform the compiler, hence the name "unbacked symbols/SymInts". +# +# Let's see how these show up in exported programs, with this example: + +class Foo(torch.nn.Module): + def forward(self, x, y): + a = x.item() + b = y.tolist() + return b + [a] + +inps = ( + torch.tensor(1), + torch.tensor([2, 3]), +) +ep = export(Foo(), inps) +print(ep) + +###################################################################### +# The result is that 3 unbacked symbols (prefixed with ``u``) are allocated and returned; 1 for the ``item()`` call, and 1 for each of the elements of ``y`` with the ``tolist()`` call. 
Note from the +# ``Range constraints`` field that these take on ranges of ``[-int_oo, int_oo]``, not the default ``[0, int_oo]`` range allocated to input shape symbols. + +###################################################################### +# Guards, torch._check() +# ^^^^^^^^^^^^^^^^^^^^^^ +# +# But the case above is easy to export, because the compiler doesn't need the concrete values of the unbacked symbols for anything. All that's relevant is that the return values are unbacked symbols. +# The data-dependent errors highlighted in this section are cases like the following, where data-dependent guards are encountered: + +class Foo(torch.nn.Module): + def forward(self, x, y): + a = x.item() + if a // 2 >= 5: + return y + 2 + else: + return y * 5 + +###################################################################### +# Here we actually need the "hint", or the concrete value of ``a`` for the compiler to decide whether to trace ``return y + 2`` or ``return y * 5``. Because the hint isn't available, the expression ``a // 2 >= 5`` +# can't be concretely evaluated, and export errors out with ``Could not guard on data-dependent expression u0 // 2 >= 5 (unhinted)``. +# +# So how do we actually export this? Unlike ``torch.compile()``, export requires full graph compilation, and we can't just graph break on this. Here's some basic options: +# +# 1. Manual specialization: we could intervene by selecting the branch to trace, either by removing the control-flow code to contain only the specialized branch, or by guarding undesired branches with ``torch.compiler.is_compiling()``. +# 2. ``torch.cond()``: we could rewrite the control-flow code to use ``torch.cond()``, keeping both branches alive. +# +# While these options are valid, they have their pitfalls. Option 1 sometimes requires drastic, invasive rewrites of the model code to specialize, and ``torch.cond()`` is not a comprehensive system for handling data-dependent errors; +# there are data-dependent errors that do not involve control-flow. +# +# The generally recommended approach is to start with ``torch._check()`` calls. While these give the impression of purely being assert statements, they are in fact a system of informing the compiler regarding properties of symbols. +# While a ``torch._check()`` call does act as an assertion at runtime, when traced at compile-time, the checked expression is deferred as a runtime assert, and any symbol properties that follow from the expression being true +# inform the symbolic shapes subsystem (provided it's smart enough to infer those properties). So even if unbacked symbols don't have hints, if we're able to describe properties that are generally true for these symbols via +# ``torch._check()`` calls, we can potentially bypass data-dependent guards without rewriting the offending model code. +# +# For example in the model above, inserting ``torch._check(a >= 10)`` tells the compiler that ``return y + 2`` can always be traced, and ``torch._check(a == 4)`` tells it to trace ``return y * 5``. +# See what happens when we re-export this model. 
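######################################################################
# (As a brief aside before the re-export: here is a rough sketch of what option 1 above,
# manual specialization, could look like, using ``torch.compiler.is_compiling()`` to pin
# the branch that gets traced while leaving eager behavior unchanged. The class name is
# illustrative, and note that this intentionally drops the ``else`` branch from the
# exported program.)

class SpecializedFoo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        # during export/compile tracing, short-circuit to the branch we want traced;
        # in eager mode the data-dependent condition is still evaluated as usual
        if torch.compiler.is_compiling() or a // 2 >= 5:
            return y + 2
        else:
            return y * 5

######################################################################
# With that aside, here is the model again, with ``torch._check()`` calls inserted: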
+ +class Foo(torch.nn.Module): + def forward(self, x, y): + a = x.item() + torch._check(a >= 10) + torch._check(a <= 60) + if a // 2 >= 5: + return y + 2 + else: + return y * 5 + +inps = ( + torch.tensor(32), + torch.randn(4), +) +ep = export(Foo(), inps) +print(ep) + +###################################################################### +# Export succeeds, and note from the ``Range constraints`` field that the ``torch._check()`` calls have informed the compiler, giving ``u0`` a range of ``[10, 60]``. +# +# So what information do ``torch._check()`` calls actually communicate? This varies as the symbolic shapes subsystem gets smarter, but at a fundamental level, these are accepted: +# +# 1. Equality with simple, non-data-dependent expressions: ``torch._check()`` calls that communicate expressions like ``u0 == s0 + 4`` or ``u0 == 5``. +# 2. Range refinement: calls that provide lower or upper bounds for symbols refine symbol ranges. +# 3. Some basic reasoning around more complicated expressions: for example, a complicated expression like ``torch._check(a ** 2 - 3 * a <= 10)`` will get you past a guard with the same expression. +# +# As mentioned previously, ``torch._check()`` calls have applicability outside of data-dependent control flow. For example, here's a model where ``torch._check()`` insertion +# prevails while manual specialization & ``torch.cond()`` do not: + +class Foo(torch.nn.Module): + def forward(self, x, y): + a = x.item() + return y[a] + +inps = ( + torch.tensor(32), + torch.randn(60), +) +export(Foo(), inps) + +###################################################################### +# Here is a scenario where ``torch._check()`` insertion is required simply to prevent an operation from failing. The export call will fail with +# ``Could not guard on data-dependent expression -u0 > 60``, implying that the compiler doesn't know if this is a valid indexing operation; +# if the value of ``x`` is out-of-bounds for ``y`` or not. Here, manual specialization is too prohibitive, and ``torch.cond()`` has no place. +# Instead, informing the compiler of ``u0``'s range is sufficient: + +class Foo(torch.nn.Module): + def forward(self, x, y): + a = x.item() + torch._check(a >= 0) + torch._check(a <= y.shape[0]) + return y[a] + +inps = ( + torch.tensor(32), + torch.randn(60), +) +ep = export(Foo(), inps) +print(ep) + +###################################################################### +# Specialized values +# ^^^^^^^^^^^^^^^^^^ +# +# Another category of data-dependent error happens when the program attempts to extract a concrete data-dependent integer/float value +# while tracing. This looks something like ``Could not extract specialized integer from data-dependent expression``, and is analogous to +# the previous class of errors; if these occur when attempting to evaluate concrete integer/float values, data-dependent guard errors arise +# with evaluating concrete boolean values. +# +# This error typically occurs when there is an explicit or implicit ``int()`` cast on a data-dependent expression. For example, list comprehension +# in Python requires an ``int()`` cast on the size of the list: + +class Foo(torch.nn.Module): + def forward(self, x, y): + a = x.item() + b = torch.cat([y for y in range(a)], dim=0) + return b + int(a) + +inps = ( + torch.tensor(32), + torch.randn(60), +) +export(Foo(), inps, strict=False) + +###################################################################### +# In this case, some basic options you have are: +# +# 1. 
Avoid unnecessary ``int()`` cast calls, in this case the ``int(a)`` in the return statement. +# 2. Use ``torch._check()`` calls; unfortunately all you may be able to do in this case is specialize (e.g. with ``torch._check(a == 60)``). +# 3. Rewrite the offending code at a higher level. For example, the list comprehension is semantically a ``repeat()`` op, which doesn't involve an ``int()`` cast. Therefore, the following rewrite avoids this error. + +class Foo(torch.nn.Module): + def forward(self, x, y): + a = x.item() + b = y.unsqueeze(0).repeat(a, 1) + return b + a + +inps = ( + torch.tensor(32), + torch.randn(60), +) +ep = export(Foo(), inps, strict=False) +print(ep) + +###################################################################### +# Data-dependent errors can be much more involved, and there are many more options in your toolkit to deal with them: ``torch._check_is_size()``, ``guard_size_oblivious()``, or real-tensor tracing, as starters. +# For a more in-depth guide, please refer to `Dealing with GuardOnDataDependentSymNode errors `. + ###################################################################### # Custom Ops # ---------- From ea1cbe676441bc266564e4087e0042450236fe22 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 22 Jan 2025 15:38:05 -0800 Subject: [PATCH 2/5] edits --- en-wordlist.txt | 4 + intermediate_source/torch_export_tutorial.py | 82 ++++++++++---------- 2 files changed, 46 insertions(+), 40 deletions(-) diff --git a/en-wordlist.txt b/en-wordlist.txt index 7c2ed6c398c..7ad997e4ca7 100644 --- a/en-wordlist.txt +++ b/en-wordlist.txt @@ -81,6 +81,7 @@ FX FX's FairSeq Fastpath +FakeTensor FFN FloydHub FloydHub's @@ -238,6 +239,7 @@ SoTA Sohn Spacy SwiGLU +SymInt TCP THP TIAToolbox @@ -368,6 +370,7 @@ downsample downsamples dropdown dtensor +dtype duration elementwise embeddings @@ -615,6 +618,7 @@ triton uint UX umap +unbacked uncomment uncommented underflowing diff --git a/intermediate_source/torch_export_tutorial.py b/intermediate_source/torch_export_tutorial.py index 26de0278851..a49f6704269 100644 --- a/intermediate_source/torch_export_tutorial.py +++ b/intermediate_source/torch_export_tutorial.py @@ -633,23 +633,23 @@ def forward(self, x, y): # Data-dependent errors # --------------------- # -# While trying to export models, you have may have encountered errors like ``Could not guard on data-dependent expression`` or ``Could not extract specialized integer from data-dependent expression``. -# Obscure as they may seem, the reasoning behind their existence, and their resolution, is actually quite straightforward. -# -# These errors exist because ``torch.export()`` compiles programs using ``FakeTensors``, which symbolically represent their real tensor counterparts (e.g. they may have the same or equivalent symbolic properties -# - sizes, strides, dtypes, etc.), but diverge in one major respect: they do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that the compiler may -# struggle with user code that relies on data values. In short, if the compiler requires a concrete, specialized value that is dependent on tensor data in order to proceed, it will error, complaining that +# While trying to export models, you have may have encountered errors like "Could not guard on data-dependent expression", or Could not extract specialized integer from data-dependent expression". 
+# These errors exist because ``torch.export()`` compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. For example, they may have equivalent symbolic properties +# (e.g. sizes, strides, dtypes), but diverge in that FakeTensors do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that export may struggle +# with parts of user code where compilation relies on data values. In short, if the compiler requires a concrete, data-dependent value in order to proceed, it will error out, complaining that # FakeTensor tracing isn't providing the information required. # -# Let's talk about where data-dependent values appear in programs. Common sources are calls like ``item()``, ``tolist()``, or ``torch.unbind()`` that extract scalar values from tensors. -# How are these values represented in the exported program? In the ``Constraints/Dynamic Shapes`` section, we talked about allocating symbols to represent dynamic input dimensions, and the same happens here - -# we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked SymInts", in contrast to the "backed" symbols/SymInts -# allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence, or absence, of a "hint" for the symbol: a concrete value backing the symbol, that can inform the compiler how to proceed. +# Data-depdenent values appear in many places, and common sources are calls like ``item()``, ``tolist()``, or ``torch.unbind()`` that extract scalar values from tensors. +# How are these values represented in the exported program? In the `Constraints/Dynamic Shapes `_ +# section, we talked about allocating symbols to represent dynamic input dimensions. +# The same happens here: we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked SymInts", +# in contrast to the "backed" symbols/SymInts allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence/absence of a "hint" for the symbol: +# a concrete value backing the symbol, that can inform the compiler on how to proceed. # -# For dynamic input shapes (backed SymInts), these hints are taken from the shapes of the sample inputs provided, which explains why sample input shapes direct the compiler in control-flow branching. -# On the other hand, data-dependent values are derived from FakeTensors during tracing, and by default lack hints to inform the compiler, hence the name "unbacked symbols/SymInts". +# In the input shape symbol case (backed SymInts), these hints are simply the sample input shapes provided, which explains why control-flow branching is determined by the sample input properties. +# For data-dependent values, the symbols are taken from FakeTensor "data" during tracing, and so the compiler doesn't know the actual values (hints) that these symbols would take on. 
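######################################################################
# To make the backed/unbacked distinction concrete, here is a small sketch (the class name and
# sample input are illustrative, and it reuses the ``Dim``/``dynamic_shapes`` API from the
# dynamic shapes section above). The dynamic input dimension receives a backed symbol whose
# hint is the sample size 8, while the ``item()`` call receives an unbacked symbol with no
# hint; both appear in the printed range constraints.

class Contrast(torch.nn.Module):
    def forward(self, x):
        u = x[0].item()  # data-dependent value -> unbacked symbol, no hint
        return x * u

x = torch.randint(0, 10, (8,))
# dimension 0 is marked dynamic -> backed symbol, hinted by the sample shape
ep = export(Contrast(), (x,), dynamic_shapes={"x": {0: torch.export.Dim("dim0")}})
print(ep)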
# -# Let's see how these show up in exported programs, with this example: +# Let's see how these show up in exported programs: class Foo(torch.nn.Module): def forward(self, x, y): @@ -665,14 +665,16 @@ def forward(self, x, y): print(ep) ###################################################################### -# The result is that 3 unbacked symbols (prefixed with ``u``) are allocated and returned; 1 for the ``item()`` call, and 1 for each of the elements of ``y`` with the ``tolist()`` call. Note from the -# ``Range constraints`` field that these take on ranges of ``[-int_oo, int_oo]``, not the default ``[0, int_oo]`` range allocated to input shape symbols. +# The result is that 3 unbacked symbols (notice they're prefixed with "u", instead of the usual "s" for input shape/backed symbols) are allocated and returned: +# 1 for the ``item()`` call, and 1 for each of the elements of ``y`` with the ``tolist()`` call. +# Note from the range constraints field that these take on ranges of ``[-int_oo, int_oo]``, not the default ``[0, int_oo]`` range allocated to input shape symbols, +# since we literally have no information on what these values are - they don't represent sizes, so don't necessarily have positive values. ###################################################################### # Guards, torch._check() # ^^^^^^^^^^^^^^^^^^^^^^ # -# But the case above is easy to export, because the compiler doesn't need the concrete values of the unbacked symbols for anything. All that's relevant is that the return values are unbacked symbols. +# But the case above is easy to export, because the concrete values of these symbols aren't used in any compiler decision-making; all that's relevant is that the return values are unbacked symbols. # The data-dependent errors highlighted in this section are cases like the following, where data-dependent guards are encountered: class Foo(torch.nn.Module): @@ -684,23 +686,23 @@ def forward(self, x, y): return y * 5 ###################################################################### -# Here we actually need the "hint", or the concrete value of ``a`` for the compiler to decide whether to trace ``return y + 2`` or ``return y * 5``. Because the hint isn't available, the expression ``a // 2 >= 5`` -# can't be concretely evaluated, and export errors out with ``Could not guard on data-dependent expression u0 // 2 >= 5 (unhinted)``. +# Here we actually need the "hint", or the concrete value of ``a`` for the compiler to decide whether to trace ``return y + 2`` or ``return y * 5`` as the output. +# Because we trace with FakeTensors, we don't know what ``a // 2 >= 5`` actually evaluates to, and export errors out with "Could not guard on data-dependent expression ``u0 // 2 >= 5`` (unhinted)". # # So how do we actually export this? Unlike ``torch.compile()``, export requires full graph compilation, and we can't just graph break on this. Here's some basic options: # -# 1. Manual specialization: we could intervene by selecting the branch to trace, either by removing the control-flow code to contain only the specialized branch, or by guarding undesired branches with ``torch.compiler.is_compiling()``. -# 2. ``torch.cond()``: we could rewrite the control-flow code to use ``torch.cond()``, keeping both branches alive. +# 1. Manual specialization: we could intervene by selecting the branch to trace, either by removing the control-flow code to contain only the specialized branch, or using ``torch.compiler.is_compiling()`` to guard what's traced at compile-time. +# 2. 
``torch.cond()``: we could rewrite the control-flow code to use ``torch.cond()`` so we don't specialize on a branch. # -# While these options are valid, they have their pitfalls. Option 1 sometimes requires drastic, invasive rewrites of the model code to specialize, and ``torch.cond()`` is not a comprehensive system for handling data-dependent errors; -# there are data-dependent errors that do not involve control-flow. +# While these options are valid, they have their pitfalls. Option 1 sometimes requires drastic, invasive rewrites of the model code to specialize, and ``torch.cond()`` is not a comprehensive system for handling data-dependent errors. +# As we will see, there are data-dependent errors that do not involve control-flow. # -# The generally recommended approach is to start with ``torch._check()`` calls. While these give the impression of purely being assert statements, they are in fact a system of informing the compiler regarding properties of symbols. -# While a ``torch._check()`` call does act as an assertion at runtime, when traced at compile-time, the checked expression is deferred as a runtime assert, and any symbol properties that follow from the expression being true -# inform the symbolic shapes subsystem (provided it's smart enough to infer those properties). So even if unbacked symbols don't have hints, if we're able to describe properties that are generally true for these symbols via +# The generally recommended approach is to start with ``torch._check()`` calls. While these give the impression of purely being assert statements, they are in fact a system of informing the compiler on properties of symbols. +# While a ``torch._check()`` call does act as an assertion at runtime, when traced at compile-time, the checked expression is sent to the symbolic shapes subsystem for reasoning, and any symbol properties that follow from the expression being true, +# are stored as symbol properties (provided it's smart enough to infer those properties). So even if unbacked symbols don't have hints, if we're able to communicate properties that are generally true for these symbols via # ``torch._check()`` calls, we can potentially bypass data-dependent guards without rewriting the offending model code. # -# For example in the model above, inserting ``torch._check(a >= 10)`` tells the compiler that ``return y + 2`` can always be traced, and ``torch._check(a == 4)`` tells it to trace ``return y * 5``. +# For example in the model above, inserting ``torch._check(a >= 10)`` would tell the compiler that ``y + 2`` can always be returned, and ``torch._check(a == 4)`` tells it to return ``y * 5``. # See what happens when we re-export this model. class Foo(torch.nn.Module): @@ -721,13 +723,13 @@ def forward(self, x, y): print(ep) ###################################################################### -# Export succeeds, and note from the ``Range constraints`` field that the ``torch._check()`` calls have informed the compiler, giving ``u0`` a range of ``[10, 60]``. +# Export succeeds, and note from the range constraints field that ``u0`` takes on a range of ``[10, 60]``. # -# So what information do ``torch._check()`` calls actually communicate? This varies as the symbolic shapes subsystem gets smarter, but at a fundamental level, these are accepted: +# So what information do ``torch._check()`` calls actually communicate? This varies as the symbolic shapes subsystem gets smarter, but at a fundamental level, these are generally true: # -# 1. 
Equality with simple, non-data-dependent expressions: ``torch._check()`` calls that communicate expressions like ``u0 == s0 + 4`` or ``u0 == 5``. -# 2. Range refinement: calls that provide lower or upper bounds for symbols refine symbol ranges. -# 3. Some basic reasoning around more complicated expressions: for example, a complicated expression like ``torch._check(a ** 2 - 3 * a <= 10)`` will get you past a guard with the same expression. +# 1. Equality with non-data-dependent expressions: ``torch._check()`` calls that communicate equalities like ``u0 == s0 + 4`` or ``u0 == 5``. +# 2. Range refinement: calls that provide lower or upper bounds for symbols, like the above. +# 3. Some basic reasoning around more complicated expressions: inserting ``torch._check(a < 4)`` will typically tell the compiler that ``a >= 4`` is false. Checks on complex expressions like ``torch._check(a ** 2 - 3 * a <= 10)`` will typically get you past identical guards. # # As mentioned previously, ``torch._check()`` calls have applicability outside of data-dependent control flow. For example, here's a model where ``torch._check()`` insertion # prevails while manual specialization & ``torch.cond()`` do not: @@ -745,7 +747,7 @@ def forward(self, x, y): ###################################################################### # Here is a scenario where ``torch._check()`` insertion is required simply to prevent an operation from failing. The export call will fail with -# ``Could not guard on data-dependent expression -u0 > 60``, implying that the compiler doesn't know if this is a valid indexing operation; +# "Could not guard on data-dependent expression ``-u0 > 60``", implying that the compiler doesn't know if this is a valid indexing operation - # if the value of ``x`` is out-of-bounds for ``y`` or not. Here, manual specialization is too prohibitive, and ``torch.cond()`` has no place. # Instead, informing the compiler of ``u0``'s range is sufficient: @@ -768,12 +770,12 @@ def forward(self, x, y): # ^^^^^^^^^^^^^^^^^^ # # Another category of data-dependent error happens when the program attempts to extract a concrete data-dependent integer/float value -# while tracing. This looks something like ``Could not extract specialized integer from data-dependent expression``, and is analogous to -# the previous class of errors; if these occur when attempting to evaluate concrete integer/float values, data-dependent guard errors arise +# while tracing. This looks something like "Could not extract specialized integer from data-dependent expression", and is analogous to +# the previous class of errors - if these occur when attempting to evaluate concrete integer/float values, data-dependent guard errors arise # with evaluating concrete boolean values. # -# This error typically occurs when there is an explicit or implicit ``int()`` cast on a data-dependent expression. For example, list comprehension -# in Python requires an ``int()`` cast on the size of the list: +# This error typically occurs when there is an explicit or implicit ``int()`` cast on a data-dependent expression. For example, this list comprehension +# has a `range()` call that implicitly does an ``int()`` cast on the size of the list: class Foo(torch.nn.Module): def forward(self, x, y): @@ -788,11 +790,11 @@ def forward(self, x, y): export(Foo(), inps, strict=False) ###################################################################### -# In this case, some basic options you have are: +# For these errors, some basic options you have are: # # 1. 
Avoid unnecessary ``int()`` cast calls, in this case the ``int(a)`` in the return statement. -# 2. Use ``torch._check()`` calls; unfortunately all you may be able to do in this case is specialize (e.g. with ``torch._check(a == 60)``). -# 3. Rewrite the offending code at a higher level. For example, the list comprehension is semantically a ``repeat()`` op, which doesn't involve an ``int()`` cast. Therefore, the following rewrite avoids this error. +# 2. Use ``torch._check()`` calls; unfortunately all you may be able to do in this case is specialize (with ``torch._check(a == 60)``). +# 3. Rewrite the offending code at a higher level. For example, the list comprehension is semantically a ``repeat()`` op, which doesn't involve an ``int()`` cast. The following rewrite avoids data-dependent errors: class Foo(torch.nn.Module): def forward(self, x, y): @@ -809,7 +811,7 @@ def forward(self, x, y): ###################################################################### # Data-dependent errors can be much more involved, and there are many more options in your toolkit to deal with them: ``torch._check_is_size()``, ``guard_size_oblivious()``, or real-tensor tracing, as starters. -# For a more in-depth guide, please refer to `Dealing with GuardOnDataDependentSymNode errors `. +# For a more in-depth guide, please refer to `Dealing with GuardOnDataDependentSymNode errors `_. ###################################################################### # Custom Ops From a0b9760204dbbee6ceed01610c587583b7c2aab7 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 22 Jan 2025 16:15:55 -0800 Subject: [PATCH 3/5] spellcheck --- en-wordlist.txt | 4 +++- intermediate_source/torch_export_tutorial.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/en-wordlist.txt b/en-wordlist.txt index 7ad997e4ca7..4d589d7db57 100644 --- a/en-wordlist.txt +++ b/en-wordlist.txt @@ -82,6 +82,7 @@ FX's FairSeq Fastpath FakeTensor +FakeTensors FFN FloydHub FloydHub's @@ -240,6 +241,7 @@ Sohn Spacy SwiGLU SymInt +SymInts TCP THP TIAToolbox @@ -371,6 +373,7 @@ downsamples dropdown dtensor dtype +dtypes duration elementwise embeddings @@ -655,7 +658,6 @@ RecSys TorchRec sharding TBE -dtype EBC sharder hyperoptimized diff --git a/intermediate_source/torch_export_tutorial.py b/intermediate_source/torch_export_tutorial.py index a49f6704269..021e5fc1b12 100644 --- a/intermediate_source/torch_export_tutorial.py +++ b/intermediate_source/torch_export_tutorial.py @@ -639,7 +639,7 @@ def forward(self, x, y): # with parts of user code where compilation relies on data values. In short, if the compiler requires a concrete, data-dependent value in order to proceed, it will error out, complaining that # FakeTensor tracing isn't providing the information required. # -# Data-depdenent values appear in many places, and common sources are calls like ``item()``, ``tolist()``, or ``torch.unbind()`` that extract scalar values from tensors. +# Data-dependent values appear in many places, and common sources are calls like ``item()``, ``tolist()``, or ``torch.unbind()`` that extract scalar values from tensors. # How are these values represented in the exported program? In the `Constraints/Dynamic Shapes `_ # section, we talked about allocating symbols to represent dynamic input dimensions. # The same happens here: we allocate symbols for every data-dependent value that appears in the program. 
The important distinction is that these are "unbacked" symbols or "unbacked SymInts", @@ -687,7 +687,7 @@ def forward(self, x, y): ###################################################################### # Here we actually need the "hint", or the concrete value of ``a`` for the compiler to decide whether to trace ``return y + 2`` or ``return y * 5`` as the output. -# Because we trace with FakeTensors, we don't know what ``a // 2 >= 5`` actually evaluates to, and export errors out with "Could not guard on data-dependent expression ``u0 // 2 >= 5`` (unhinted)". +# Because we trace with FakeTensors, we don't know what ``a // 2 >= 5`` actually evaluates to, and export errors out with "Could not guard on data-dependent expression ``u0 // 2 >= 5 (unhinted)``". # # So how do we actually export this? Unlike ``torch.compile()``, export requires full graph compilation, and we can't just graph break on this. Here's some basic options: # From e9e9caf60037dca6974f9e652836b1702060e634 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 22 Jan 2025 16:20:14 -0800 Subject: [PATCH 4/5] Update torch_export_tutorial.py --- intermediate_source/torch_export_tutorial.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intermediate_source/torch_export_tutorial.py b/intermediate_source/torch_export_tutorial.py index 021e5fc1b12..46f95b69092 100644 --- a/intermediate_source/torch_export_tutorial.py +++ b/intermediate_source/torch_export_tutorial.py @@ -633,7 +633,7 @@ def forward(self, x, y): # Data-dependent errors # --------------------- # -# While trying to export models, you have may have encountered errors like "Could not guard on data-dependent expression", or Could not extract specialized integer from data-dependent expression". +# While trying to export models, you have may have encountered errors like "Could not guard on data-dependent expression", or "Could not extract specialized integer from data-dependent expression". # These errors exist because ``torch.export()`` compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. For example, they may have equivalent symbolic properties # (e.g. sizes, strides, dtypes), but diverge in that FakeTensors do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that export may struggle # with parts of user code where compilation relies on data values. 
In short, if the compiler requires a concrete, data-dependent value in order to proceed, it will error out, complaining that From 2f9dd136c129f1ccd9628b476aef814fd86d9dae Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Thu, 23 Jan 2025 09:30:38 -0800 Subject: [PATCH 5/5] address nits --- en-wordlist.txt | 2 -- intermediate_source/torch_export_tutorial.py | 25 ++++++++++---------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/en-wordlist.txt b/en-wordlist.txt index 4d589d7db57..b56df45df0c 100644 --- a/en-wordlist.txt +++ b/en-wordlist.txt @@ -240,8 +240,6 @@ SoTA Sohn Spacy SwiGLU -SymInt -SymInts TCP THP TIAToolbox diff --git a/intermediate_source/torch_export_tutorial.py b/intermediate_source/torch_export_tutorial.py index 021e5fc1b12..c992eefa9fc 100644 --- a/intermediate_source/torch_export_tutorial.py +++ b/intermediate_source/torch_export_tutorial.py @@ -634,19 +634,19 @@ def forward(self, x, y): # --------------------- # # While trying to export models, you have may have encountered errors like "Could not guard on data-dependent expression", or Could not extract specialized integer from data-dependent expression". -# These errors exist because ``torch.export()`` compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. For example, they may have equivalent symbolic properties -# (e.g. sizes, strides, dtypes), but diverge in that FakeTensors do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that export may struggle -# with parts of user code where compilation relies on data values. In short, if the compiler requires a concrete, data-dependent value in order to proceed, it will error out, complaining that -# FakeTensor tracing isn't providing the information required. +# These errors exist because ``torch.export()`` compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. While these have equivalent symbolic properties +# (e.g. sizes, strides, dtypes), they diverge in that FakeTensors do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that export may be +# unable to out-of-the-box compile parts of user code where compilation relies on data values. In short, if the compiler requires a concrete, data-dependent value in order to proceed, it will error out, +# complaining that the value is not available. # # Data-dependent values appear in many places, and common sources are calls like ``item()``, ``tolist()``, or ``torch.unbind()`` that extract scalar values from tensors. # How are these values represented in the exported program? In the `Constraints/Dynamic Shapes `_ # section, we talked about allocating symbols to represent dynamic input dimensions. -# The same happens here: we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked SymInts", -# in contrast to the "backed" symbols/SymInts allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence/absence of a "hint" for the symbol: -# a concrete value backing the symbol, that can inform the compiler on how to proceed. +# The same happens here: we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols, +# in contrast to the "backed" symbols allocated for input dimensions. 
The `"backed/unbacked" `_ +# nomenclature refers to the presence/absence of a "hint" for the symbol: a concrete value backing the symbol, that can inform the compiler on how to proceed. # -# In the input shape symbol case (backed SymInts), these hints are simply the sample input shapes provided, which explains why control-flow branching is determined by the sample input properties. +# In the input shape symbol case (backed symbols), these hints are simply the sample input shapes provided, which explains why control-flow branching is determined by the sample input properties. # For data-dependent values, the symbols are taken from FakeTensor "data" during tracing, and so the compiler doesn't know the actual values (hints) that these symbols would take on. # # Let's see how these show up in exported programs: @@ -668,14 +668,14 @@ def forward(self, x, y): # The result is that 3 unbacked symbols (notice they're prefixed with "u", instead of the usual "s" for input shape/backed symbols) are allocated and returned: # 1 for the ``item()`` call, and 1 for each of the elements of ``y`` with the ``tolist()`` call. # Note from the range constraints field that these take on ranges of ``[-int_oo, int_oo]``, not the default ``[0, int_oo]`` range allocated to input shape symbols, -# since we literally have no information on what these values are - they don't represent sizes, so don't necessarily have positive values. +# since we have no information on what these values are - they don't represent sizes, so don't necessarily have positive values. ###################################################################### # Guards, torch._check() # ^^^^^^^^^^^^^^^^^^^^^^ # # But the case above is easy to export, because the concrete values of these symbols aren't used in any compiler decision-making; all that's relevant is that the return values are unbacked symbols. -# The data-dependent errors highlighted in this section are cases like the following, where data-dependent guards are encountered: +# The data-dependent errors highlighted in this section are cases like the following, where `data-dependent guards `_ are encountered: class Foo(torch.nn.Module): def forward(self, x, y): @@ -689,7 +689,7 @@ def forward(self, x, y): # Here we actually need the "hint", or the concrete value of ``a`` for the compiler to decide whether to trace ``return y + 2`` or ``return y * 5`` as the output. # Because we trace with FakeTensors, we don't know what ``a // 2 >= 5`` actually evaluates to, and export errors out with "Could not guard on data-dependent expression ``u0 // 2 >= 5 (unhinted)``". # -# So how do we actually export this? Unlike ``torch.compile()``, export requires full graph compilation, and we can't just graph break on this. Here's some basic options: +# So how do we export this toy model? Unlike ``torch.compile()``, export requires full graph compilation, and we can't just graph break on this. Here are some basic options: # # 1. Manual specialization: we could intervene by selecting the branch to trace, either by removing the control-flow code to contain only the specialized branch, or using ``torch.compiler.is_compiling()`` to guard what's traced at compile-time. # 2. ``torch.cond()``: we could rewrite the control-flow code to use ``torch.cond()`` so we don't specialize on a branch. 
@@ -811,7 +811,8 @@ def forward(self, x, y): ###################################################################### # Data-dependent errors can be much more involved, and there are many more options in your toolkit to deal with them: ``torch._check_is_size()``, ``guard_size_oblivious()``, or real-tensor tracing, as starters. -# For a more in-depth guide, please refer to `Dealing with GuardOnDataDependentSymNode errors `_. +# For more in-depth guides, please refer to the `Export Programming Model `_, +# or `Dealing with GuardOnDataDependentSymNode errors `_. ###################################################################### # Custom Ops