# While trying to export models, you may have encountered errors like ``Could not guard on data-dependent expression`` or ``Could not extract specialized integer from data-dependent expression``.
# Obscure as they may seem, the reasoning behind their existence, and their resolution, is actually quite straightforward.
#
# These errors exist because ``torch.export()`` compiles programs using ``FakeTensors``, which symbolically represent their real tensor counterparts (e.g. they may have the same or equivalent symbolic properties
# - sizes, strides, dtypes, etc.), but diverge in one major respect: they do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that the compiler may
# struggle with user code that relies on data values. In short, if the compiler requires a concrete, specialized value that is dependent on tensor data in order to proceed, it will error, complaining that
# FakeTensor tracing isn't providing the information required.
#
# Let's talk about where data-dependent values appear in programs. Common sources are calls like ``item()``, ``tolist()``, or ``torch.unbind()`` that extract scalar values from tensors.
# How are these values represented in the exported program? In the ``Constraints/Dynamic Shapes`` section, we talked about allocating symbols to represent dynamic input dimensions, and the same happens here -
# we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols or "unbacked SymInts", in contrast to the "backed" symbols/SymInts
# allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence, or absence, of a "hint" for the symbol: a concrete value backing the symbol, that can inform the compiler how to proceed.
#
# For dynamic input shapes (backed SymInts), these hints are taken from the shapes of the sample inputs provided, which explains why sample input shapes direct the compiler in control-flow branching.
# On the other hand, data-dependent values are derived from FakeTensors during tracing, and by default lack hints to inform the compiler, hence the name "unbacked symbols/SymInts".
#
# Let's see how these show up in exported programs, with this example:
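# (A minimal sketch; the module name and the sample inputs, a 0-d integer ``x`` and a two-element ``y``, are illustrative assumptions.)

import torch
from torch.export import export

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()    # one unbacked SymInt for the scalar value of x
        b = y.tolist()  # one unbacked SymInt per element of y
        return b + [a]

inps = (
    torch.tensor(5),
    torch.tensor([2, 3]),
)
ep = export(Foo(), inps)
print(ep)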
# The result is that 3 unbacked symbols (prefixed with ``u``) are allocated and returned; 1 for the ``item()`` call, and 1 for each of the elements of ``y`` with the ``tolist()`` call. Note from the
# ``Range constraints`` field that these take on ranges of ``[-int_oo, int_oo]``, not the default ``[0, int_oo]`` range allocated to input shape symbols.
# But the case above is easy to export, because the compiler doesn't need the concrete values of the unbacked symbols for anything. All that's relevant is that the return values are unbacked symbols.
# The data-dependent errors highlighted in this section are cases like the following, where data-dependent guards are encountered:
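# (Again a minimal sketch, reusing the imports above; the branching condition is taken from the error message quoted below, while the module name and sample inputs are assumptions.)

import traceback

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        # tracing must pick one branch here, which requires the concrete value of ``a``
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

try:
    export(Foo(), (torch.tensor(32), torch.ones(4)))
except Exception:
    traceback.print_exc()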
# Here we actually need the "hint", or the concrete value of ``a`` for the compiler to decide whether to trace ``return y + 2`` or ``return y * 5``. Because the hint isn't available, the expression ``a // 2 >= 5``
# can't be concretely evaluated, and export errors out with ``Could not guard on data-dependent expression u0 // 2 >= 5 (unhinted)``.
#
# So how do we actually export this? Unlike ``torch.compile()``, export requires full graph compilation, and we can't just graph break on this. Here are some basic options:
#
# 1. Manual specialization: we could intervene by selecting the branch to trace, either by removing the control-flow code to contain only the specialized branch, or by guarding undesired branches with ``torch.compiler.is_compiling()``.
# 2. ``torch.cond()``: we could rewrite the control-flow code to use ``torch.cond()``, keeping both branches alive, as sketched below.
#
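# For instance, option 2 applied to the branching model above might look roughly like this (a sketch; ``torch.cond()`` takes the predicate, the two branch functions, and a tuple of tensor operands):

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()

        def branch_add(y):
            return y + 2

        def branch_mul(y):
            return y * 5

        # both branches are captured in the exported graph; the predicate is evaluated at runtime
        return torch.cond(a // 2 >= 5, branch_add, branch_mul, (y,))

ep = export(Foo(), (torch.tensor(32), torch.ones(4)))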
# While these options are valid, they have their pitfalls. Option 1 sometimes requires drastic, invasive rewrites of the model code to specialize, and ``torch.cond()`` is not a comprehensive system for handling data-dependent errors;
# there are data-dependent errors that do not involve control-flow.
#
# The generally recommended approach is to start with ``torch._check()`` calls. While these give the impression of purely being assert statements, they are in fact a mechanism for informing the compiler about properties of symbols.
# While a ``torch._check()`` call does act as an assertion at runtime, when traced at compile-time, the checked expression is deferred as a runtime assert, and any symbol properties that follow from the expression being true
# inform the symbolic shapes subsystem (provided it's smart enough to infer those properties). So even if unbacked symbols don't have hints, if we're able to describe properties that are generally true for these symbols via
# ``torch._check()`` calls, we can potentially bypass data-dependent guards without rewriting the offending model code.
#
# For example in the model above, inserting ``torch._check(a >= 10)`` tells the compiler that ``return y + 2`` can always be traced, and ``torch._check(a == 4)`` tells it to trace ``return y * 5``.
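# Here is a sketch of the model with such checks inserted; the lower bound of 10 comes from the sentence above, while the upper bound of 60 is an assumption chosen to match the range reported below:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        # the lower bound makes ``a // 2 >= 5`` statically true; the upper bound
        # further refines the symbol's range
        torch._check(a >= 10)
        torch._check(a <= 60)
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

ep = export(Foo(), (torch.tensor(32), torch.ones(4)))
print(ep)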
# Export succeeds, and note from the ``Range constraints`` field that the ``torch._check()`` calls have informed the compiler, giving ``u0`` a range of ``[10, 60]``.
#
# So what information do ``torch._check()`` calls actually communicate? This varies as the symbolic shapes subsystem gets smarter, but at a fundamental level, these are accepted:
#
# 1. Equality with simple, non-data-dependent expressions: ``torch._check()`` calls that communicate expressions like ``u0 == s0 + 4`` or ``u0 == 5``.
# 2. Range refinement: calls that provide lower or upper bounds for symbols refine symbol ranges.
# 3. Some basic reasoning around more complicated expressions: for example, a complicated expression like ``torch._check(a ** 2 - 3 * a <= 10)`` will get you past a guard with the same expression.
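# For instance, the third kind can be exercised directly by asserting the same expression that a later branch guards on (a sketch; the module and sample inputs are assumptions):

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        # the exact expression that the branch below guards on
        torch._check(a ** 2 - 3 * a <= 10)
        if a ** 2 - 3 * a <= 10:
            return y + 2
        return y * 5

ep = export(Foo(), (torch.tensor(2), torch.ones(4)))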
#
# As mentioned previously, ``torch._check()`` calls have applicability outside of data-dependent control flow. For example, here's a model where ``torch._check()`` insertion
# prevails while manual specialization & ``torch.cond()`` do not:
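# A sketch of such a model (the module, shapes, and sample inputs are assumptions): there is no branch to specialize and no control flow to hand to ``torch.cond()``, yet the element-wise addition below still guards on the two data-dependent sizes matching (roughly ``Eq(u0, u1)``), which ``torch._check()`` can resolve:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = y.item()
        # depending on the PyTorch version, size-like unbacked values may also
        # need a non-negativity check before being used as tensor sizes
        torch._check(a >= 0)
        torch._check(b >= 0)
        # informs the compiler that the two data-dependent sizes are equal,
        # resolving the broadcast guard on the addition below
        torch._check(a == b)
        return torch.zeros(a) + torch.zeros(b)

ep = export(Foo(), (torch.tensor(3), torch.tensor(3)))

# The numbered options below refer to a second kind of failure, the ``Could not extract specialized integer from data-dependent expression`` error mentioned at the start of this section, which arises when tracing needs a concrete Python integer rather than a guard. A hypothetical model in that spirit (the sample value 60 and the 1-D ``y`` are assumptions) is:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        # ``range(int(a))`` needs a concrete Python integer, so export fails with
        # ``Could not extract specialized integer from data-dependent expression``
        return torch.cat([y for _ in range(int(a))], dim=0)

try:
    export(Foo(), (torch.tensor(60), torch.ones(4)))
except Exception:
    traceback.print_exc()

# Some options for getting past this error: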
# 1. Avoid unnecessary ``int()`` cast calls, in this case the ``int(a)`` in the return statement.
# 2. Use ``torch._check()`` calls; unfortunately all you may be able to do in this case is specialize (e.g. with ``torch._check(a == 60)``).
# 3. Rewrite the offending code at a higher level. For example, the list comprehension is semantically a ``repeat()`` op, which doesn't involve an ``int()`` cast. Therefore, the following rewrite avoids this error.
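# A sketch of that rewrite (same assumed inputs as above; depending on the version, an additional ``torch._check(a >= 0)`` or ``torch._check_is_size(a)`` may be needed to mark the repeat count as non-negative):

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        # ``repeat()`` consumes the symbolic value directly; no int() cast, no Python loop
        return y.repeat(a)

ep = export(Foo(), (torch.tensor(60), torch.ones(4)))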
# Data-dependent errors can be much more involved, and there are many more options in your toolkit to deal with them: ``torch._check_is_size()``, ``guard_size_oblivious()``, or real-tensor tracing, as starters.
# For a more in-depth guide, please refer to `Dealing with GuardOnDataDependentSymNode errors <https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs>`_.