Skip to content

Commit 24d1c3a

Browse files
committed
fix some formatting
1 parent b4aee8e commit 24d1c3a

File tree

1 file changed

+99
-57
lines changed

1 file changed

+99
-57
lines changed

intermediate_source/torch_export_tutorial.py

Lines changed: 99 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,42 @@
77
"""
88

99
######################################################################
10-
# :func:`torch.export` is the PyTorch 2.0 way to export PyTorch models into
11-
# static and standardized model representations, intended
10+
# :func:`torch.export` is the PyTorch 2 way to export PyTorch models into
11+
# standardized model representations, intended
1212
# to be run on different (i.e. Python-less) environments.
1313
#
1414
# In this tutorial, you will learn how to use :func:`torch.export` to extract
15-
# `ExportedProgram`s (i.e. single-graph representations) from PyTorch programs.
15+
# ``ExportedProgram``'s (i.e. single-graph representations) from PyTorch programs.
1616
# We also detail some considerations/modifications that you may need
1717
# to make in order to make your model compatible with ``torch.export``.
1818
#
19+
# **Contents**
1920
# .. contents::
2021
# :local:
2122

2223
######################################################################
23-
# Exporting a PyTorch model using ``torch.export``
24-
# ------------------------------------------------
24+
# Basic Usage
25+
# -----------
26+
#
27+
# ``torch.export`` extracts single-graph representations from PyTorch programs
28+
# by tracing the target function, given example inputs.
29+
#
30+
# The signature of ``torch.export`` is:
2531
#
26-
# ``torch.export`` takes in a callable (including ``torch.nn.Module`` s),
27-
# a tuple of positional arguments, and optionally (not shown in the example below),
28-
# a dictionary of keyword arguments and a list of constraints (covered later).
32+
# .. code:: python
33+
#
34+
# export(
35+
# f: Callable,
36+
# args: Tuple[Any, ...],
37+
# kwargs: Optional[Dict[str, Any]] = None,
38+
# *,
39+
# constraints: Optional[List[Constraint]] = None
40+
# ) -> ExportedProgram
41+
#
42+
# ``torch.export`` traces the tensor computation graph from calling ``f(*args, **kwargs)``
43+
# and wraps it in an ``ExportedProgram``, which can be serialized or executed later with
44+
# different inputs. Note that while the output ``ExportedGraph`` is callable, it is not a
45+
# ``torch.nn.Module``. We will detail the ``constraints`` argument later in the tutorial.
2946

3047
import torch
3148
from torch.export import export
@@ -41,11 +58,6 @@ def forward(self, x, y):
4158
mod = MyModule()
4259
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
4360
print(type(exported_mod))
44-
45-
######################################################################
46-
# ``torch.export`` returns an ``ExportedProgram``. It is not a ``torch.nn.Module``,
47-
# but it can still be run as a function:
48-
4961
print(exported_mod(torch.randn(8, 100), torch.randn(8, 100)))
5062

5163
######################################################################
@@ -79,18 +91,25 @@ def forward(self, x, y):
7991
# - ``range_constraints`` and ``equality_constraints`` -- Constraints, covered later
8092

8193
print(exported_mod.graph_signature)
94+
print(exported_mod.range_constraints)
95+
print(exported_mod.equality_constraints)
96+
97+
######################################################################
98+
# See the ``torch.export`` `documentation <https://pytorch.org/docs/main/export.html#torch.export.export>`__
99+
# for more details.
82100

83101
######################################################################
84-
# Comparison to ``torch.compile``
85-
# -------------------------------
102+
# Graph Breaks
103+
# ------------
86104
#
87-
# Although ``torch.export`` is built on top of the ``torch.compile``
88-
# components, the key limitation of ``torch.export`` is that it does not
105+
# Although ``torch.export`` shares components with ``torch.compile``,
106+
# the key limitation of ``torch.export``, especially when compared to ``torch.compile``, is that it does not
89107
# support graph breaks. This is because handling graph breaks involves interpreting
90108
# the unsupported operation with default Python evaluation, which is incompatible
91-
# with the export use case.
109+
# with the export use case. Therefore, in order to make your model code compatible
110+
# with ``torch.export``, you will need to modify your code to remove graph breaks.
92111
#
93-
# A graph break is necessary in the following cases:
112+
# A graph break is necessary in cases such as:
94113
#
95114
# - data-dependent control flow
96115

@@ -145,9 +164,14 @@ def bad4(x):
145164
except Exception:
146165
tb.print_exc()
147166

167+
######################################################################
168+
# The sections below demonstrate some ways you can modify your code
169+
# in order to remove graph breaks.
170+
148171
######################################################################
149172
# Control Flow Ops
150173
# ----------------
174+
#
151175
# .. warning::
152176
#
153177
# ``cond`` is a prototype feature in PyTorch, included as a part of the ``torch.export`` release.
@@ -157,7 +181,9 @@ def bad4(x):
157181
# ``torch.export`` actually does support data-dependent control flow.
158182
# But these need to be expressed using control flow ops. For example,
159183
# we can fix the control flow example above using the ``cond`` op, like so:
160-
# <!-- TODO link to docs about cond when it is out -->
184+
185+
# ..
186+
# [TODO] link to docs about cond when it is out
161187

162188
from functorch.experimental.control_flow import cond
163189

@@ -178,37 +204,36 @@ def false_fn(x):
178204
# - The predicate (i.e. ``x.sum() > 0``) must result in a boolean or a single-element tensor.
179205
# - The operands (i.e. ``[x]``) must be tensors.
180206
# - The branch function (i.e. ``true_fn`` and ``false_fn``) signature must match with the
181-
# operands and they must both return a single tensor with the same metadata (for example, ``dtype``, ``shape``, etc.).
207+
# operands and they must both return a single tensor with the same metadata (for example, ``dtype``, ``shape``, etc.).
182208
# - Branch functions cannot mutate input or global variables.
183209
# - Branch functions cannot access closure variables, except for ``self`` if the function is
184210
# defined in the scope of a method.
185211

186-
# <!-- NOTE map is not documented at the moment
187-
188212
######################################################################
189-
# We can also use ``map``, which applies a function across the first dimension
190-
# of the first tensor argument.
191-
192-
# from functorch.experimental.control_flow import map
193-
194-
# def map_example(xs):
195-
# def map_fn(x, const):
196-
# def true_fn(x):
197-
# return x + const
198-
# def false_fn(x):
199-
# return x - const
200-
# return control_flow.cond(x.sum() > 0, true_fn, false_fn, [x])
201-
# return control_flow.map(map_fn, xs, torch.tensor([2.0]))
202-
203-
# exported_map_example= export(map_example, (torch.randn(4, 3),))
204-
# inp = torch.cat((torch.ones(2, 3), -torch.ones(2, 3)))
205-
# print(exported_map_example(inp))
206-
207-
# -->
213+
# ..
214+
# [NOTE] map is not documented at the moment
215+
# We can also use ``map``, which applies a function across the first dimension
216+
# of the first tensor argument.
217+
#
218+
# from functorch.experimental.control_flow import map
219+
#
220+
# def map_example(xs):
221+
# def map_fn(x, const):
222+
# def true_fn(x):
223+
# return x + const
224+
# def false_fn(x):
225+
# return x - const
226+
# return control_flow.cond(x.sum() > 0, true_fn, false_fn, [x])
227+
# return control_flow.map(map_fn, xs, torch.tensor([2.0]))
228+
#
229+
# exported_map_example= export(map_example, (torch.randn(4, 3),))
230+
# inp = torch.cat((torch.ones(2, 3), -torch.ones(2, 3)))
231+
# print(exported_map_example(inp))
208232

209233
######################################################################
210234
# Constraints
211235
# -----------
236+
#
212237
# .. warning::
213238
#
214239
# The constraints API is a prototype feature in PyTorch, included as a part of the torch.export release.
@@ -230,7 +255,9 @@ def false_fn(x):
230255
# relax some of these constraints. We use ``torch.export.dynamic_dim`` to
231256
# express shape constraints manually.
232257
#
233-
# <!-- TODO link to doc of dynamic_dim when it is available -->
258+
# ..
259+
# [TODO] link to doc of dynamic_dim when it is available
260+
#
234261
# Using ``dynamic_dim`` on a tensor's dimension marks it as dynamic (i.e. unconstrained), and
235262
# we can provide additional upper and lower bound shape constraints.
236263
# The first argument of ``dynamic_dim`` is the tensor variable we wish
@@ -269,8 +296,8 @@ def constraints_example1(x):
269296
tb.print_exc()
270297

271298
######################################################################
272-
# Note that if our inputs to ``torch.export`` do not satisfy the constraints,
273-
# we get an error.
299+
# Note that if our example inputs to ``torch.export`` do not satisfy the constraints,
300+
# then we get an error.
274301

275302
constraints1_bad = [
276303
dynamic_dim(inp1, 0),
@@ -309,7 +336,9 @@ def constraints_example2(x, y):
309336

310337
######################################################################
311338
# We can actually use ``torch.export`` to guide us as to which constraints
312-
# are necessary. We can do this by relaxing all constraints and letting ``torch.export``
339+
# are necessary. We can do this by relaxing all constraints (recall that if we
340+
# do not provide constraints for a dimension, the default behavior is to constrain
341+
# to the exact shape value of the example input) and letting ``torch.export``
313342
# error out.
314343

315344
inp4 = torch.randn(8, 16)
@@ -372,10 +401,7 @@ def specify_constraints(x, y):
372401
# We can also constrain on individual values in the source code itself using
373402
# ``constrain_as_value`` and ``constrain_as_size``. ``constrain_as_value`` specifies
374403
# that a given integer value is expected to fall within the provided minimum/maximum bounds (inclusive).
375-
# If a bound is not provided, then it is assumed to be unbounded. ``constrain_as_size``
376-
# is similar to ``constrain_as_value``, except that it should be used on integer values that
377-
# will be used to specify tensor shapes -- in particular, the value must not be 0 or 1 because
378-
# many operations have special behavior for tensors with a shape value of 0 or 1.
404+
# If a bound is not provided, then it is assumed to be unbounded.
379405

380406
from torch.export import constrain_as_size, constrain_as_value
381407

@@ -393,6 +419,11 @@ def constraints_example4(x, y):
393419
except Exception:
394420
tb.print_exc()
395421

422+
######################################################################
423+
# ``constrain_as_size`` is similar to ``constrain_as_value``, except that it should be used on integer values that
424+
# will be used to specify tensor shapes -- in particular, the value must not be 0 or 1 because
425+
# many operations have special behavior for tensors with a shape value of 0 or 1.
426+
396427
def constraints_example5(x, y):
397428
b = y.item()
398429
constrain_as_size(b)
@@ -409,15 +440,18 @@ def constraints_example5(x, y):
409440
######################################################################
410441
# Custom Ops
411442
# ----------
443+
#
412444
# ``torch.export`` can export PyTorch programs with custom operators.
413445
#
414-
# NOTE: the API for registering custom ops is still under active development
415-
# and may change without notice.
446+
# .. warning::
447+
#
448+
# The API for registering custom ops is still under active development
449+
# and may change without notice.
416450
#
417451
# Currently, the steps to register a custom op for use by ``torch.export`` are:
418452
#
419453
# - Define the custom op using ``torch.library`` (`reference <https://pytorch.org/docs/main/library.html>`__)
420-
# as with any other custom op
454+
# as with any other custom op
421455

422456
from torch.library import Library, impl
423457

@@ -430,13 +464,15 @@ def custom_op(x):
430464
print("custom_op called!")
431465
return torch.relu(x)
432466

467+
######################################################################
433468
# - Define a ``"Meta"`` implementation of the custom op that returns an empty
434-
# tensor with the same shape as the expected output
469+
# tensor with the same shape as the expected output
435470

436471
@impl(m, "custom_op", "Meta")
437472
def custom_op_meta(x):
438473
return torch.empty_like(x)
439474

475+
######################################################################
440476
# - Call the custom op from the code you want to export using ``torch.ops``
441477

442478
def custom_op_example(x):
@@ -445,24 +481,27 @@ def custom_op_example(x):
445481
x = torch.cos(x)
446482
return x
447483

484+
######################################################################
448485
# - Export the code as before
449486

450487
exported_custom_op_example = export(custom_op_example, (torch.randn(3, 3),))
451488
exported_custom_op_example.graph_module.print_readable()
452489
print(exported_custom_op_example(torch.randn(3, 3)))
453490

491+
######################################################################
454492
# Note in the above outputs that the custom op is included in the exported graph.
455493
# And when we call the exported graph as a function, the original custom op is called,
456494
# as evidenced by the ``print`` call.
457495

458496
######################################################################
459497
# ExportDB
460498
# --------
499+
#
461500
# ``torch.export`` will only ever export a single computation graph from a PyTorch program. Because of this requirement,
462501
# there will be Python or PyTorch features that are not compatible with ``torch.export``, which will require users to
463502
# rewrite parts of their model code. We have seen examples of this earlier in the tutorial -- for example, rewriting
464503
# if-statements using ``cond``.
465-
504+
#
466505
# `ExportDB <https://pytorch.org/docs/main/generated/exportdb/index.html>`__ is the standard reference that documents
467506
# supported and unsupported Python/PyTorch features for ``torch.export``. It is essentially a list a program samples, each
468507
# of which represents the usage of one particular Python/PyTorch feature and its interaction with ``torch.export``.
@@ -481,17 +520,20 @@ def cond_predicate(x):
481520
pred = x.dim() > 2 and x.shape[2] > 10
482521
return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
483522

523+
######################################################################
484524
# More generally, ExportDB can be used as a reference when one of the following occurs:
525+
#
485526
# 1. Before attempting ``torch.export``, you know ahead of time that your model uses some tricky Python/PyTorch features
486527
# and you want to know if ``torch.export`` covers that feature.
487528
# 2. When attempting ``torch.export``, there is a failure and it's unclear how to work around it.
488-
529+
#
489530
# ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach
490531
# out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by ``torch.export``.
491532

492533
######################################################################
493534
# Conclusion
494535
# ----------
495-
# We introduced ``torch.export``, the new PyTorch 2.0 way to export single computation
536+
#
537+
# We introduced ``torch.export``, the new PyTorch 2 way to export single computation
496538
# graphs from PyTorch programs. In particular, we demonstrate several code modifications
497539
# and considerations (control flow ops, constraints, etc.) that need to be made in order to export a graph.

0 commit comments

Comments
 (0)