Commit fd67881

1 parent: 9edb9fb


intermediate_source/torch_export_tutorial.py

Lines changed: 40 additions & 25 deletions
@@ -190,7 +190,7 @@ def forward(self, x):
 # about safety, but not all Python code is supported, causing these graph
 # breaks.
 #
-# To address this issue, in PyTorch 2.5, we introduced a new mode of
+# To address this issue, in PyTorch 2.3, we introduced a new mode of
 # exporting called non-strict mode, where we trace through the program using the
 # Python interpreter executing it exactly as it would in eager mode, allowing us
 # to skip over unsupported Python features. This is done through adding a
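
For context, the non-strict mode described in this hunk is requested through the ``strict`` flag of ``torch.export.export``; a minimal sketch, with a toy module assumed purely for illustration:

    import torch
    from torch.export import export

    class M(torch.nn.Module):
        def forward(self, x):
            return x.cos() + 1

    # strict=True traces with TorchDynamo; strict=False (non-strict mode) runs the
    # module under the normal Python interpreter while recording ATen operations.
    ep = export(M(), (torch.randn(3),), strict=False)
    print(ep)
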
@@ -306,7 +306,7 @@ def false_fn(x):
 #
 # This section covers dynamic behavior and representation of exported programs. Dynamic behavior is
 # subjective to the particular model being exported, so for the most part of this tutorial, we'll focus
-# on this particular toy model (with the sample input shapes annotated):
+# on this particular toy model (with the resulting tensor shapes annotated):

 class DynamicModel(torch.nn.Module):
     def __init__(self):
@@ -320,14 +320,14 @@ def forward(
         y: torch.Tensor, # [8, 4]
         z: torch.Tensor, # [32]
     ):
-        x0 = x + y # output shape: [8, 4]
+        x0 = x + y # [8, 4]
         x1 = self.l(w) # [6, 3]
         x2 = x0.flatten() # [32]
         x3 = x2 + z # [32]
         return x1, x3

 ######################################################################
-# By default, ``torch.export`` produces a static program. One clear consequence of this is that at runtime,
+# By default, ``torch.export`` produces a static program. One consequence of this is that at runtime,
 # the program won't work on inputs with different shapes, even if they're valid in eager mode.

 w = torch.randn(6, 5)
@@ -339,6 +339,9 @@ def forward(
 model(w, x, torch.randn(3, 4), torch.randn(12))
 ep.module()(w, x, torch.randn(3, 4), torch.randn(12))

+# Basic concepts: symbols and guards
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
 ######################################################################
 # To enable dynamism, ``export()`` provides a ``dynamic_shapes`` argument. The easiest way to work with
 # dynamic shapes is using ``Dim.AUTO`` and looking at the program that's returned. Dynamic behavior is specified
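
The ``dynamic_shapes`` argument takes a mapping from input names to per-dimension markers. As a brief sketch for the toy ``DynamicModel`` above (the exact specification used in the elided part of the tutorial may differ):

    from torch.export import Dim, export

    # Mark every dimension of every input as automatically dynamic; export then
    # decides which dimensions can actually vary, based on the guards it collects.
    dynamic_shapes = {
        "w": (Dim.AUTO, Dim.AUTO),
        "x": (Dim.AUTO,),
        "y": (Dim.AUTO, Dim.AUTO),
        "z": (Dim.AUTO,),
    }
    ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
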
@@ -357,7 +360,8 @@ def forward(
 ######################################################################
 # Before we look at the program that's produced, let's understand what specifying ``dynamic_shapes`` entails,
 # and how that interacts with export. For every input dimension where a ``Dim`` object is specified, a symbol is
-# allocated, taking on a range of ``[2, inf]`` (why not ``[0, inf]`` or ``[1, inf]``? we'll explain later in the
+# `allocated <https://pytorch.org/docs/main/export.programming_model.html#basics-of-symbolic-shapes>`,
+# taking on a range of ``[2, inf]`` (why not ``[0, inf]`` or ``[1, inf]``? we'll explain later in the
 # 0/1 specialization section).
 #
 # Export then runs model tracing, looking at each operation that's performed by the model. Each individual operation can emit
@@ -383,7 +387,7 @@ def forward(
     ):
         x0 = x + y # guard: s2 == s4
         x1 = self.l(w) # guard: s1 == 5
-        x2 = x0.flatten()
+        x2 = x0.flatten() # no guard added here
         x3 = x2 + z # guard: s3 * s4 == s5
         return x1, x3

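The ``create_symbol`` and ``runtime_assert`` log lines quoted in the next hunk come from symbolic-shapes logging. One way to reproduce them, assuming the ``torch._logging.set_logs`` helper and the ``TORCH_LOGS`` environment variable behave as in recent PyTorch releases:

    import logging
    import torch

    # Roughly equivalent to running the script with TORCH_LOGS="dynamic".
    torch._logging.set_logs(dynamic=logging.INFO)  # use logging.DEBUG for even more detail
    ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)  # the tutorial's model and inputs
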
@@ -425,26 +429,28 @@ def forward(
 ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

 ######################################################################
-# This spits out quite a handful, even with this simple toy model. But looking through the logs we can see the lines relevant
-# to what we described above; e.g. the allocation of symbols:
+# This spits out quite a handful, even with this simple toy model. The log lines here have been cut short at front and end
+# to ignore unnecessary info, but looking through the logs we can see the lines relevant to what we described above;
+# e.g. the allocation of symbols:

 """
-I1210 16:20:19.720000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
-I1210 16:20:19.722000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
-V1210 16:20:19.722000 3417744 torch/fx/experimental/symbolic_shapes.py:6535] [1/0] runtime_assert True == True [statically known]
-I1210 16:20:19.727000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
-I1210 16:20:19.729000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
-I1210 16:20:19.731000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
-I1210 16:20:19.734000 3417744 torch/fx/experimental/symbolic_shapes.py:4404] [1/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
+create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
+create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
+runtime_assert True == True [statically known]
+create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
+create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
+create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
+create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
 """

 ######################################################################
-# Or the guards emitted:
+# The lines with `create_symbol` show when a new symbol has been allocated, and the logs also identify the tensor variable names
+# and dimensions they've been allocated for. In other lines we can also see the guards emitted:

 """
-I1210 16:20:19.743000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y # output shape: [8, 4] # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
-I1210 16:20:19.754000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
-I1210 16:20:19.775000 3417744 torch/fx/experimental/symbolic_shapes.py:6234] [1/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
+runtime_assert Eq(s2, s4) [guard added] x0 = x + y # output shape: [8, 4] # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
+runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
+runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
 """

 ######################################################################
@@ -456,14 +462,14 @@ def forward(
 #
 # ``Dim.AUTO`` is just one of the available options for interacting with ``dynamic_shapes``; as of writing this 2 other options are available:
 # ``Dim.DYNAMIC``, and ``Dim.STATIC``. ``Dim.STATIC`` simply marks a dimension static, while ``Dim.DYNAMIC`` is similar to ``Dim.AUTO`` in all
-# ways except one: it raises an error when specializing to a constant; designed to maintain dynamism. See for example what happens when a
+# ways except one: it raises an error when specializing to a constant; this is designed to maintain dynamism. See for example what happens when a
 # static guard is emitted on a dynamically-marked dimension:

 dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
 export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

 ######################################################################
-# Static guards also aren't always inherent to the model; they can also come from user-specifications. In fact, a common pitfall leading to shape
+# Static guards also aren't always inherent to the model; they can also come from user specifications. In fact, a common pitfall leading to shape
 # specializations is when the user specifies conflicting markers for equivalent dimensions; one dynamic and another static. The same error type is
 # raised when this is the case for ``x.shape[0]`` and ``y.shape[1]``:

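The specification that triggers this sits in elided lines of the file, so the markers below are only a plausible sketch of the pitfall being described: ``x + y`` broadcasts ``x`` over ``y``'s rows, forcing ``x.shape[0] == y.shape[1]``, so marking one side static and the other ``Dim.DYNAMIC`` ends in a specialization error.

    # Hypothetical conflicting markers (not necessarily the tutorial's exact spec):
    dynamic_shapes["x"] = (Dim.STATIC,)
    dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC)
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)  # raises the specialization error described above
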
@@ -473,12 +479,12 @@ def forward(
 export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

 ######################################################################
-# Here you might ask why export "specializes"; why we resolve this static/dynamic conflict by going with the static route. The answer is because
+# Here you might ask why export "specializes", i.e. why we resolve this static/dynamic conflict by going with the static route. The answer is because
 # of the symbolic shapes system described above, of symbols and guards. When ``x.shape[0]`` is marked static, we don't allocate a symbol, and compile
 # treating this shape as a concrete integer 4. A symbol is allocated for ``y.shape[1]``, and so we finally emit the guard ``s3 == 4``, leading to
 # specialization.
 #
-# One feature of export is that during tracing, statements like asserts, ``torch._checks()``, and ``if/else`` conditions will also emit guards.
+# One feature of export is that during tracing, statements like asserts, ``torch._check()``, and ``if/else`` conditions will also emit guards.
 # See what happens when we augment the existing model with such statements:

 class DynamicModel(torch.nn.Module):
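
The augmented model's body is elided from this hunk; as a standalone sketch of the idea (hypothetical module, not the tutorial's code), a ``torch._check`` or a shape-dependent ``if`` each become a guard during tracing:

    import torch
    from torch.export import Dim, export

    class CheckedModel(torch.nn.Module):
        def forward(self, x):
            torch._check(x.shape[0] >= 8)   # traced as a guard on x's shape symbol
            if x.shape[0] % 2 == 0:         # evaluating this condition also emits a guard
                return x.reshape(2, -1)
            return x

    ep = export(CheckedModel(), (torch.randn(16),), dynamic_shapes={"x": (Dim.AUTO,)})
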
@@ -516,7 +522,10 @@ def forward(self, w, x, y, z):
 # If different sample input shapes were provided that fail the ``if`` condition, export would trace and emit guards corresponding to the ``else`` branch.
 # Additionally, you might ask why we traced only the ``if`` branch, and if it's possible to maintain control-flow in your program and keep both branches
 # alive. For that, refer to rewriting your model code following the ``Control Flow Ops`` section above.
-#
+
+# 0/1 specialization
+# ^^^^^^^^^^^^^^^^^^
+
 # Since we're talking about guards and specializations, it's a good time to talk about the 0/1 specialization issue we brought up earlier.
 # The bottom line is that export will specialize on sample input dimensions with value 0 or 1, because these shapes have trace-time properties that
 # don't generalize to other shapes. For example, size 1 tensors can broadcast while other sizes fail; and size 0 ... . This just means that you should
@@ -532,6 +541,9 @@ def forward(self, w, x, y, z):
 )
 ep.module()(torch.randn(2, 4))

+# Named Dims
+# ^^^^^^^^^^
+
 ######################################################################
 # So far we've only been talking about 3 ways to specify dynamic shapes: ``Dim.AUTO``, ``Dim.DYNAMIC``, and ``Dim.STATIC``. The attraction of these is the
 # low-friction user experience; all the guards emitted during model tracing are adhered to, and dynamic behavior like min/max ranges, relations, and static/dynamic
@@ -567,6 +579,9 @@ def forward(self, w, x, y, z):
     "x": (4 * dx, None) # x.shape[0] has range [16, 2048], and is divisible by 4.
 }

+# Constraint violations, suggested fixes
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
 ######################################################################
 # One common issue with this specification style (before ``Dim.AUTO`` was introduced), is that the specification would often be mismatched with what was produced by model tracing.
 # That would lead to ``ConstraintViolation`` errors and export suggested fixes - see for example with this model & specification, where the model inherently requires equality between
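
The ``dx`` referenced in the spec above is constructed in an elided part of the file; a hypothetical construction consistent with the annotated range (``x.shape[0]`` in ``[16, 2048]``, divisible by 4):

    from torch.export import Dim

    dx = Dim("dx", min=4, max=512)          # hypothetical name and bounds; 4 * dx then spans [16, 2048] in steps of 4
    dynamic_shapes = {"x": (4 * dx, None)}  # None leaves the second dimension static
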
@@ -594,7 +609,7 @@ def forward(self, x, y):
 #
 # - ``None`` is a good option for static behavior:
 # - ``dynamic_shapes=None`` (default) exports with the entire model being static.
-# - specifying ``None`` at an input-level exports with all tensor dimensions static, and alternatively is also required for non-tensor inputs.
+# - specifying ``None`` at an input-level exports with all tensor dimensions static, and is also required for non-tensor inputs.
 # - specifying ``None`` at a dimension-level specializes that dimension, though this is deprecated in favor of ``Dim.STATIC``.
 # - specifying per-dimension integer values also produces static behavior, and will additionally check that the provided sample input matches the specification.
 #
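
A quick sketch of the ``None`` conventions summarized in this hunk, with toy shapes assumed for illustration:

    import torch
    from torch.export import Dim, export

    class Add(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    a, b = torch.randn(8, 4), torch.randn(8, 4)

    export(Add(), (a, b))  # dynamic_shapes=None (default): the whole program is static
    export(Add(), (a, b), dynamic_shapes={
        "a": None,                    # input-level None: every dimension of a is static
        "b": (Dim.AUTO, Dim.STATIC),  # per-dimension markers for b
    })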
