Commit ea1cbe6

edits
1 parent d6a992b commit ea1cbe6

2 files changed: +46 -40 lines

en-wordlist.txt

Lines changed: 4 additions & 0 deletions
@@ -81,6 +81,7 @@ FX
 FX's
 FairSeq
 Fastpath
+FakeTensor
 FFN
 FloydHub
 FloydHub's
@@ -238,6 +239,7 @@ SoTA
 Sohn
 Spacy
 SwiGLU
+SymInt
 TCP
 THP
 TIAToolbox
@@ -368,6 +370,7 @@ downsample
 downsamples
 dropdown
 dtensor
+dtype
 duration
 elementwise
 embeddings
@@ -615,6 +618,7 @@ triton
 uint
 UX
 umap
+unbacked
 uncomment
 uncommented
 underflowing

intermediate_source/torch_export_tutorial.py

Lines changed: 42 additions & 40 deletions
@@ -633,23 +633,23 @@ def forward(self, x, y):
 # Data-dependent errors
 # ---------------------
 #
-# While trying to export models, you have may have encountered errors like ``Could not guard on data-dependent expression`` or ``Could not extract specialized integer from data-dependent expression``.
-# Obscure as they may seem, the reasoning behind their existence, and their resolution, is actually quite straightforward.
-#
-# These errors exist because ``torch.export()`` compiles programs using ``FakeTensors``, which symbolically represent their real tensor counterparts (e.g. they may have the same or equivalent symbolic properties
-# - sizes, strides, dtypes, etc.), but diverge in one major respect: they do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that the compiler may
-# struggle with user code that relies on data values. In short, if the compiler requires a concrete, specialized value that is dependent on tensor data in order to proceed, it will error, complaining that
+# While trying to export models, you may have encountered errors like "Could not guard on data-dependent expression" or "Could not extract specialized integer from data-dependent expression".
+# These errors exist because ``torch.export()`` compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. For example, they may have equivalent symbolic properties
+# (e.g. sizes, strides, dtypes), but diverge in that FakeTensors do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that export may struggle
+# with parts of user code where compilation relies on data values. In short, if the compiler requires a concrete, data-dependent value in order to proceed, it will error out, complaining that
 # FakeTensor tracing isn't providing the information required.
 #
-# Let's talk about where data-dependent values appear in programs. Common sources are calls like ``item()``, ``tolist()``, or ``torch.unbind()`` that extract scalar values from tensors.
-# How are these values represented in the exported program? In the ``Constraints/Dynamic Shapes`` section, we talked about allocating symbols to represent dynamic input dimensions, and the same happens here -
-# we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked SymInts", in contrast to the "backed" symbols/SymInts
-# allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence, or absence, of a "hint" for the symbol: a concrete value backing the symbol, that can inform the compiler how to proceed.
+# Data-dependent values appear in many places, and common sources are calls like ``item()``, ``tolist()``, or ``torch.unbind()`` that extract scalar values from tensors.
+# How are these values represented in the exported program? In the `Constraints/Dynamic Shapes <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html#constraints-dynamic-shapes>`_
+# section, we talked about allocating symbols to represent dynamic input dimensions.
+# The same happens here: we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked SymInts",
+# in contrast to the "backed" symbols/SymInts allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence/absence of a "hint" for the symbol:
+# a concrete value backing the symbol that can inform the compiler how to proceed.
 #
-# For dynamic input shapes (backed SymInts), these hints are taken from the shapes of the sample inputs provided, which explains why sample input shapes direct the compiler in control-flow branching.
-# On the other hand, data-dependent values are derived from FakeTensors during tracing, and by default lack hints to inform the compiler, hence the name "unbacked symbols/SymInts".
+# In the input shape symbol case (backed SymInts), these hints are simply the sample input shapes provided, which explains why control-flow branching is determined by the sample input properties.
+# For data-dependent values, the symbols are taken from FakeTensor "data" during tracing, and so the compiler doesn't know the actual values (hints) that these symbols would take on.
 #
-# Let's see how these show up in exported programs, with this example:
+# Let's see how these show up in exported programs:

 class Foo(torch.nn.Module):
     def forward(self, x, y):
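
For reference, a minimal sketch of the kind of model the prose above describes. The tutorial's actual code is elided by the hunk context, so the module body and sample inputs here are assumptions:

    import torch
    from torch.export import export

    class Foo(torch.nn.Module):
        def forward(self, x, y):
            a = x.item()    # allocates one unbacked SymInt (u0)
            b = y.tolist()  # allocates one unbacked SymInt per element (u1, u2)
            return a, b

    # A 0-dim tensor for item(), and a 2-element tensor so that tolist()
    # produces 2 symbols: 3 unbacked symbols in total.
    inps = (torch.tensor(7), torch.tensor([3, 4]))
    ep = export(Foo(), inps, strict=False)
    print(ep)  # range constraints show u0, u1, u2 with range [-int_oo, int_oo]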
@@ -665,14 +665,16 @@ def forward(self, x, y):
 print(ep)

 ######################################################################
-# The result is that 3 unbacked symbols (prefixed with ``u``) are allocated and returned; 1 for the ``item()`` call, and 1 for each of the elements of ``y`` with the ``tolist()`` call. Note from the
-# ``Range constraints`` field that these take on ranges of ``[-int_oo, int_oo]``, not the default ``[0, int_oo]`` range allocated to input shape symbols.
+# The result is that 3 unbacked symbols (notice they're prefixed with "u", instead of the usual "s" for input shape/backed symbols) are allocated and returned:
+# 1 for the ``item()`` call, and 1 for each of the elements of ``y`` with the ``tolist()`` call.
+# Note from the range constraints field that these take on ranges of ``[-int_oo, int_oo]``, not the default ``[0, int_oo]`` range allocated to input shape symbols,
+# since we literally have no information on what these values are - they don't represent sizes, so they don't necessarily have positive values.

 ######################################################################
 # Guards, torch._check()
 # ^^^^^^^^^^^^^^^^^^^^^^
 #
-# But the case above is easy to export, because the compiler doesn't need the concrete values of the unbacked symbols for anything. All that's relevant is that the return values are unbacked symbols.
+# But the case above is easy to export, because the concrete values of these symbols aren't used in any compiler decision-making; all that's relevant is that the return values are unbacked symbols.
 # The data-dependent errors highlighted in this section are cases like the following, where data-dependent guards are encountered:

 class Foo(torch.nn.Module):
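
For reference, the model body is elided between this hunk and the next; based on the surrounding prose, it looks roughly like the following sketch (sample inputs are assumptions):

    import torch
    from torch.export import export

    class Foo(torch.nn.Module):
        def forward(self, x, y):
            a = x.item()
            # Branching on a data-dependent value requires a hint for u0,
            # which FakeTensor tracing cannot provide.
            if a // 2 >= 5:
                return y + 2
            return y * 5

    try:
        export(Foo(), (torch.tensor(32), torch.randn(4)), strict=False)
    except Exception as e:
        print(e)  # Could not guard on data-dependent expression u0 // 2 >= 5 (unhinted)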
@@ -684,23 +686,23 @@ def forward(self, x, y):
         return y * 5

 ######################################################################
-# Here we actually need the "hint", or the concrete value of ``a`` for the compiler to decide whether to trace ``return y + 2`` or ``return y * 5``. Because the hint isn't available, the expression ``a // 2 >= 5``
-# can't be concretely evaluated, and export errors out with ``Could not guard on data-dependent expression u0 // 2 >= 5 (unhinted)``.
+# Here we actually need the "hint", or the concrete value of ``a``, for the compiler to decide whether to trace ``return y + 2`` or ``return y * 5`` as the output.
+# Because we trace with FakeTensors, we don't know what ``a // 2 >= 5`` actually evaluates to, and export errors out with "Could not guard on data-dependent expression ``u0 // 2 >= 5`` (unhinted)".
 #
 # So how do we actually export this? Unlike ``torch.compile()``, export requires full graph compilation, and we can't just graph break on this. Here are some basic options:
 #
-# 1. Manual specialization: we could intervene by selecting the branch to trace, either by removing the control-flow code to contain only the specialized branch, or by guarding undesired branches with ``torch.compiler.is_compiling()``.
-# 2. ``torch.cond()``: we could rewrite the control-flow code to use ``torch.cond()``, keeping both branches alive.
+# 1. Manual specialization: we could intervene by selecting the branch to trace, either by removing the control-flow code to contain only the specialized branch, or by using ``torch.compiler.is_compiling()`` to guard what's traced at compile-time.
+# 2. ``torch.cond()``: we could rewrite the control-flow code to use ``torch.cond()`` so we don't specialize on a branch.
 #
-# While these options are valid, they have their pitfalls. Option 1 sometimes requires drastic, invasive rewrites of the model code to specialize, and ``torch.cond()`` is not a comprehensive system for handling data-dependent errors;
-# there are data-dependent errors that do not involve control-flow.
+# While these options are valid, they have their pitfalls. Option 1 sometimes requires drastic, invasive rewrites of the model code to specialize, and ``torch.cond()`` is not a comprehensive system for handling data-dependent errors.
+# As we will see, there are data-dependent errors that do not involve control flow.
 #
-# The generally recommended approach is to start with ``torch._check()`` calls. While these give the impression of purely being assert statements, they are in fact a system of informing the compiler regarding properties of symbols.
-# While a ``torch._check()`` call does act as an assertion at runtime, when traced at compile-time, the checked expression is deferred as a runtime assert, and any symbol properties that follow from the expression being true
-# inform the symbolic shapes subsystem (provided it's smart enough to infer those properties). So even if unbacked symbols don't have hints, if we're able to describe properties that are generally true for these symbols via
+# The generally recommended approach is to start with ``torch._check()`` calls. While these give the impression of purely being assert statements, they are in fact a system for informing the compiler of properties of symbols.
+# While a ``torch._check()`` call does act as an assertion at runtime, when traced at compile-time, the checked expression is sent to the symbolic shapes subsystem for reasoning, and any symbol properties that follow from the expression being true
+# are stored as symbol properties (provided the subsystem is smart enough to infer those properties). So even if unbacked symbols don't have hints, if we're able to communicate properties that are generally true for these symbols via
 # ``torch._check()`` calls, we can potentially bypass data-dependent guards without rewriting the offending model code.
 #
-# For example in the model above, inserting ``torch._check(a >= 10)`` tells the compiler that ``return y + 2`` can always be traced, and ``torch._check(a == 4)`` tells it to trace ``return y * 5``.
+# For example, in the model above, inserting ``torch._check(a >= 10)`` would tell the compiler that ``y + 2`` can always be returned, and ``torch._check(a == 4)`` tells it to return ``y * 5``.
 # See what happens when we re-export this model.

 class Foo(torch.nn.Module):
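
For reference, a minimal sketch of the ``torch._check()`` fix described above; the re-exported model is elided by the hunk, so the exact checks and inputs are assumptions (the bounds are chosen to match the ``[10, 60]`` range quoted below):

    import torch
    from torch.export import export

    class Foo(torch.nn.Module):
        def forward(self, x, y):
            a = x.item()
            # Refine u0's range to [10, 60]. From a >= 10 the compiler can
            # conclude that a // 2 >= 5 always holds, so only the first
            # branch is traced; the checks remain as runtime asserts.
            torch._check(a >= 10)
            torch._check(a <= 60)
            if a // 2 >= 5:
                return y + 2
            return y * 5

    ep = export(Foo(), (torch.tensor(32), torch.randn(4)), strict=False)
    print(ep)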
@@ -721,13 +723,13 @@ def forward(self, x, y):
 print(ep)

 ######################################################################
-# Export succeeds, and note from the ``Range constraints`` field that the ``torch._check()`` calls have informed the compiler, giving ``u0`` a range of ``[10, 60]``.
+# Export succeeds, and note from the range constraints field that ``u0`` takes on a range of ``[10, 60]``.
 #
-# So what information do ``torch._check()`` calls actually communicate? This varies as the symbolic shapes subsystem gets smarter, but at a fundamental level, these are accepted:
+# So what information do ``torch._check()`` calls actually communicate? This varies as the symbolic shapes subsystem gets smarter, but at a fundamental level, these are generally true:
 #
-# 1. Equality with simple, non-data-dependent expressions: ``torch._check()`` calls that communicate expressions like ``u0 == s0 + 4`` or ``u0 == 5``.
-# 2. Range refinement: calls that provide lower or upper bounds for symbols refine symbol ranges.
-# 3. Some basic reasoning around more complicated expressions: for example, a complicated expression like ``torch._check(a ** 2 - 3 * a <= 10)`` will get you past a guard with the same expression.
+# 1. Equality with non-data-dependent expressions: ``torch._check()`` calls that communicate equalities like ``u0 == s0 + 4`` or ``u0 == 5``.
+# 2. Range refinement: calls that provide lower or upper bounds for symbols, like the above.
+# 3. Some basic reasoning around more complicated expressions: inserting ``torch._check(a < 4)`` will typically tell the compiler that ``a >= 4`` is false. Checks on complex expressions like ``torch._check(a ** 2 - 3 * a <= 10)`` will typically get you past identical guards.
 #
 # As mentioned previously, ``torch._check()`` calls have applicability outside of data-dependent control flow. For example, here's a model where ``torch._check()`` insertion
 # prevails while manual specialization & ``torch.cond()`` do not:
@@ -745,7 +747,7 @@ def forward(self, x, y):

 ######################################################################
 # Here is a scenario where ``torch._check()`` insertion is required simply to prevent an operation from failing. The export call will fail with
-# ``Could not guard on data-dependent expression -u0 > 60``, implying that the compiler doesn't know if this is a valid indexing operation;
+# "Could not guard on data-dependent expression ``-u0 > 60``", implying that the compiler doesn't know if this is a valid indexing operation -
 # whether the value of ``x`` is out-of-bounds for ``y`` or not. Here, manual specialization is too prohibitive, and ``torch.cond()`` has no place.
 # Instead, informing the compiler of ``u0``'s range is sufficient:

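For reference, a sketch of the range-refinement fix described above. The actual model is elided by the hunk; the 60-element size is inferred from the quoted error message, and the exact code is an assumption:

    import torch
    from torch.export import export

    class Foo(torch.nn.Module):
        def forward(self, x, y):
            a = x.item()
            # Bound u0 on both sides so the compiler can prove y[a] is
            # in-bounds; without these checks, export fails with a
            # data-dependent guard such as -u0 > 60.
            torch._check(a >= 0)
            torch._check(a < y.shape[0])
            return y[a]

    ep = export(Foo(), (torch.tensor(5), torch.randn(60)), strict=False)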
@@ -768,12 +770,12 @@ def forward(self, x, y):
 # ^^^^^^^^^^^^^^^^^^
 #
 # Another category of data-dependent error happens when the program attempts to extract a concrete data-dependent integer/float value
-# while tracing. This looks something like ``Could not extract specialized integer from data-dependent expression``, and is analogous to
-# the previous class of errors; if these occur when attempting to evaluate concrete integer/float values, data-dependent guard errors arise
+# while tracing. This looks something like "Could not extract specialized integer from data-dependent expression", and is analogous to
+# the previous class of errors - where these occur when attempting to evaluate concrete integer/float values, data-dependent guard errors arise
 # when evaluating concrete boolean values.
 #
-# This error typically occurs when there is an explicit or implicit ``int()`` cast on a data-dependent expression. For example, list comprehension
-# in Python requires an ``int()`` cast on the size of the list:
+# This error typically occurs when there is an explicit or implicit ``int()`` cast on a data-dependent expression. For example, this list comprehension
+# has a ``range()`` call that implicitly does an ``int()`` cast on the size of the list:

 class Foo(torch.nn.Module):
     def forward(self, x, y):
@@ -788,11 +790,11 @@ def forward(self, x, y):
 export(Foo(), inps, strict=False)

 ######################################################################
-# In this case, some basic options you have are:
+# For these errors, some basic options you have are:
 #
 # 1. Avoid unnecessary ``int()`` cast calls, in this case the ``int(a)`` in the return statement.
-# 2. Use ``torch._check()`` calls; unfortunately all you may be able to do in this case is specialize (e.g. with ``torch._check(a == 60)``).
-# 3. Rewrite the offending code at a higher level. For example, the list comprehension is semantically a ``repeat()`` op, which doesn't involve an ``int()`` cast. Therefore, the following rewrite avoids this error.
+# 2. Use ``torch._check()`` calls; unfortunately, all you may be able to do in this case is specialize (with ``torch._check(a == 60)``).
+# 3. Rewrite the offending code at a higher level. For example, the list comprehension is semantically a ``repeat()`` op, which doesn't involve an ``int()`` cast. The following rewrite avoids data-dependent errors:

 class Foo(torch.nn.Module):
     def forward(self, x, y):
def forward(self, x, y):
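For reference, a sketch contrasting the failing ``int()`` cast pattern with the ``repeat()`` rewrite named in option 3. The tutorial's actual models are elided by the hunks, so the comprehension body and inputs are assumptions:

    import torch
    from torch.export import export

    class Foo(torch.nn.Module):
        def forward(self, x, y):
            a = x.item()
            torch._check(a >= 0)  # repeat counts must be non-negative
            # Fails: range(a) implicitly int()-casts the unbacked SymInt:
            #   return torch.cat([y for _ in range(a)])
            # Works: repeat() expresses the same computation symbolically,
            # with no int() cast on a.
            return y.repeat(a)

    ep = export(Foo(), (torch.tensor(3), torch.randn(4)), strict=False)
    print(ep)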
@@ -809,7 +811,7 @@ def forward(self, x, y):

 ######################################################################
 # Data-dependent errors can be much more involved, and there are many more options in your toolkit to deal with them: ``torch._check_is_size()``, ``guard_size_oblivious()``, or real-tensor tracing, as starters.
-# For a more in-depth guide, please refer to `Dealing with GuardOnDataDependentSymNode errors <https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs>`.
+# For a more in-depth guide, please refer to `Dealing with GuardOnDataDependentSymNode errors <https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs>`_.

 ######################################################################
 # Custom Ops
