Commit b4aee8e ("address comments", parent 9a28c7a)

1 file changed: intermediate_source/torch_export_tutorial.py (+94, -59 lines)
@@ -7,33 +7,25 @@
"""

######################################################################
# :func:`torch.export` is the PyTorch 2.0 way to export PyTorch models into
# static and standardized model representations, intended
# to be run in different (that is, Python-less) environments.
#
# In this tutorial, you will learn how to use :func:`torch.export` to extract
# ``ExportedProgram`` s (that is, single-graph representations) from PyTorch programs.
# We also detail some considerations and modifications that you may need to make
# in order to make your model compatible with ``torch.export``.
#
# .. contents::
#     :local:

######################################################################
# Exporting a PyTorch model using ``torch.export``
# ------------------------------------------------
#
# ``torch.export`` takes in a callable (including ``torch.nn.Module`` s),
# a tuple of positional arguments, and, optionally (not shown in the example below),
# a dictionary of keyword arguments and a list of constraints (covered later).

import torch
from torch.export import export
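
######################################################################
# As a minimal sketch of that call signature (this toy ``Mod`` and its example
# inputs are illustrative assumptions; the tutorial's own module definition is
# elided from this excerpt):

class Mod(torch.nn.Module):
    def forward(self, x, y):
        z = x + y
        # an in-place op; the functionalization discussion below assumes one
        return torch.nn.functional.relu(z, inplace=True)

mod = Mod()
# export takes the callable and a tuple of example positional inputs
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))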
@@ -51,16 +43,18 @@ def forward(self, x, y):
print(type(exported_mod))

######################################################################
# ``torch.export`` returns an ``ExportedProgram``. It is not a ``torch.nn.Module``,
# but it can still be run as a function:

print(exported_mod(torch.randn(8, 100), torch.randn(8, 100)))

######################################################################
# Let's review some attributes of ``ExportedProgram`` that are of interest.
#
# The ``graph`` attribute is an `FX graph <https://pytorch.org/docs/stable/fx.html#torch.fx.Graph>`__
# traced from the function we exported, that is, the computation graph of all PyTorch operations.
# The FX graph has some important properties:
#
# - The operations are "ATen-level" operations.
# - The graph is "functionalized", meaning that no operations are mutations.
#
@@ -73,12 +67,14 @@ def forward(self, x, y):
exported_mod.graph_module.print_readable()

######################################################################
# The printed code shows that the FX graph contains only ATen-level ops (such as ``torch.ops.aten``)
# and that mutations were removed. For example, the mutating op ``torch.nn.functional.relu(..., inplace=True)``
# is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate.
# Future uses of the input to the original mutating ``relu`` op are replaced by the additional new output
# of the replacement non-mutating ``relu`` op (a sketch follows this block).
#
# Other attributes of interest in ``ExportedProgram`` include:
#
# - ``graph_signature`` -- the inputs, outputs, parameters, buffers, etc. of the exported graph.
# - ``range_constraints`` and ``equality_constraints`` -- constraints, covered later.
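
######################################################################
# To make that rewrite concrete, here is a hand-written sketch of the
# transformation (our own illustration, not actual exporter output).
#
# Before functionalization (mutating)::
#
#     torch.nn.functional.relu(z, inplace=True)  # overwrites z in place
#     out = z * 2                                # reads the mutated z
#
# After functionalization (non-mutating)::
#
#     z_new = torch.ops.aten.relu.default(z)     # fresh tensor; z is untouched
#     out = z_new * 2                            # later uses of z now read z_new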

@@ -87,13 +83,15 @@ def forward(self, x, y):
######################################################################
# Comparison to ``torch.compile``
# -------------------------------
#
# Although ``torch.export`` is built on top of the ``torch.compile``
# components, the key limitation of ``torch.export`` is that it does not
# support graph breaks. This is because handling graph breaks involves interpreting
# the unsupported operation with default Python evaluation, which is incompatible
# with the export use case.
#
# A graph break is necessary in the following cases:
#
# - data-dependent control flow

def bad1(x):
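    # sketch of the elided body, inferred from ``bad1_fixed`` later in this
    # tutorial: branching on a tensor's values is data-dependent control flow
    if x.sum() > 0:
        return torch.sin(x)
    return torch.cos(x)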
@@ -119,9 +117,8 @@ def bad2(x):
except Exception:
    tb.print_exc()

######################################################################
# - calling unsupported functions (such as many built-in functions)

def bad3(x):
    x = x + 1
@@ -151,58 +148,75 @@ def bad4(x):
######################################################################
# Control Flow Ops
# ----------------
#
# .. warning::
#
#     ``cond`` is a prototype feature in PyTorch, included as a part of the
#     ``torch.export`` release. Future changes may break backwards compatibility.
#     Please look forward to a more stable implementation in a future version of PyTorch.
#
# ``torch.export`` actually does support data-dependent control flow,
# but it needs to be expressed using control flow ops. For example,
# we can fix the control flow example above using the ``cond`` op, like so:
# <!-- TODO link to docs about cond when it is out -->

from functorch.experimental.control_flow import cond

def bad1_fixed(x):
    def true_fn(x):
        return torch.sin(x)
    def false_fn(x):
        return torch.cos(x)
    return cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(bad1_fixed, (torch.randn(3, 3),))
print(exported_bad1_fixed(torch.ones(3, 3)))
print(exported_bad1_fixed(-torch.ones(3, 3)))

######################################################################
# There are limitations to ``cond`` that one should be aware of
# (a short sketch follows this list):
#
# - The predicate (i.e. ``x.sum() > 0``) must result in a boolean or a single-element tensor.
# - The operands (i.e. ``[x]``) must be tensors.
# - The branch function (i.e. ``true_fn`` and ``false_fn``) signatures must match with the
#   operands, and both branches must return a single tensor with the same metadata
#   (for example, ``dtype`` and ``shape``).
# - Branch functions cannot mutate input or global variables.
# - Branch functions cannot access closure variables, except for ``self`` if the function is
#   defined in the scope of a method.
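
######################################################################
# The following is our own short sketch honoring these limitations (the name
# ``cond_two_operands`` is an assumption, not from the tutorial): both branches
# accept the same operands and return a single tensor with matching metadata.

def cond_two_operands(x, y):
    def true_fn(x, y):
        return x + y
    def false_fn(x, y):
        return x - y
    # the predicate is a single-element tensor; the operands are a list of tensors
    return cond(x.sum() > 0, true_fn, false_fn, [x, y])

exported_two_operands = export(cond_two_operands, (torch.randn(3, 3), torch.randn(3, 3)))
print(exported_two_operands(torch.ones(3, 3), torch.ones(3, 3)))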

# <!-- NOTE: map is not documented at the moment

######################################################################
# We can also use ``map``, which applies a function across the first dimension
# of the first tensor argument.

# from functorch.experimental.control_flow import map

# def map_example(xs):
#     def map_fn(x, const):
#         def true_fn(x):
#             return x + const
#         def false_fn(x):
#             return x - const
#         return cond(x.sum() > 0, true_fn, false_fn, [x])
#     return map(map_fn, xs, torch.tensor([2.0]))

# exported_map_example = export(map_example, (torch.randn(4, 3),))
# inp = torch.cat((torch.ones(2, 3), -torch.ones(2, 3)))
# print(exported_map_example(inp))

# -->

######################################################################
# Constraints
# -----------
#
# .. warning::
#
#     The constraints API is a prototype feature in PyTorch, included as a part of the
#     ``torch.export`` release. Backwards compatibility is not guaranteed. We anticipate
#     releasing a more stable constraints API in the future.
#
# Ops can have different specializations or behaviors for different tensor shapes, so by default,
# ``torch.export`` requires inputs to ``ExportedProgram`` to have the same shape as the respective
# example inputs given to the initial ``torch.export`` call.
# If we try to run the first ``ExportedProgram`` example with a tensor
# with a different shape, we get an error:
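
######################################################################
# A sketch of such a failing call (our own snippet; the tutorial's own snippet is
# elided from this excerpt, and ``tb`` refers to the traceback module used earlier):

try:
    exported_mod(torch.randn(10, 100), torch.randn(10, 100))
except Exception:
    tb.print_exc()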

@@ -212,15 +226,19 @@ def false_fn(x):
    tb.print_exc()

######################################################################
# We can modify the ``torch.export`` call to
# relax some of these constraints. We use ``torch.export.dynamic_dim`` to
# express shape constraints manually.
#
# <!-- TODO link to doc of dynamic_dim when it is available -->
# Using ``dynamic_dim`` on a tensor's dimension marks it as dynamic (that is, unconstrained), and
# we can provide additional upper and lower bound shape constraints.
# The first argument of ``dynamic_dim`` is the tensor variable we wish
# to specify a dimension constraint for. The second argument specifies
# the dimension of the first argument that the constraint applies to.
# In the example below, our input
# ``inp1`` has an unconstrained first dimension, but the size of the second
# dimension must be in the interval (3, 18].

from torch.export import dynamic_dim
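
######################################################################
# A sketch of how these constraints might be declared (the tutorial's own
# definitions of ``inp1``, ``constraints1``, and ``constraints_example1`` are
# elided from this excerpt; the shape and function body below are assumptions):

inp1 = torch.randn(10, 10)

constraints1 = [
    dynamic_dim(inp1, 0),          # first dimension: dynamic, unconstrained
    3 < dynamic_dim(inp1, 1),      # second dimension: lower bound (exclusive)
    dynamic_dim(inp1, 1) <= 18,    # second dimension: upper bound (inclusive)
]

def constraints_example1(x):
    # placeholder body for illustration only
    return x.sin()

export(constraints_example1, (inp1,), constraints=constraints1)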
@@ -250,6 +268,20 @@ def constraints_example1(x):
except Exception:
    tb.print_exc()

######################################################################
# Note that if our inputs to ``torch.export`` do not satisfy the constraints,
# we get an error.

constraints1_bad = [
    dynamic_dim(inp1, 0),
    10 < dynamic_dim(inp1, 1),
    dynamic_dim(inp1, 1) <= 18,
]
try:
    export(constraints_example1, (inp1,), constraints=constraints1_bad)
except Exception:
    tb.print_exc()

######################################################################
# We can also use ``dynamic_dim`` to enforce expected equalities between
# dimensions, for example, in matrix multiplication:
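
######################################################################
# For instance, a sketch of such an equality constraint (our own illustration;
# the tutorial's matrix-multiplication example is elided from this excerpt):

A = torch.randn(4, 5)
B = torch.randn(5, 6)

def matmul_example(a, b):
    return a @ b  # valid only when a.shape[1] == b.shape[0]

constraints_mm = [
    # keep the inner dimensions equal so the matmul stays well-formed
    dynamic_dim(A, 1) == dynamic_dim(B, 0),
]
exported_mm = export(matmul_example, (A, B), constraints=constraints_mm)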
@@ -343,6 +375,7 @@ def specify_constraints(x, y):
# If a bound is not provided, then it is assumed to be unbounded. ``constrain_as_size``
# is similar to ``constrain_as_value``, except that it should be used on integer values that
# will be used to specify tensor shapes -- in particular, the value must not be 0 or 1, because
# many operations have special behavior for tensors with a shape value of 0 or 1.

from torch.export import constrain_as_size, constrain_as_value
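
######################################################################
# A minimal sketch of ``constrain_as_size`` (our own example; the function
# name, bounds, and inputs are assumptions):

def data_dependent_size_example(x):
    n = x.max().item()  # a data-dependent integer read out of a tensor
    constrain_as_size(n, min=2, max=16)  # n will be used as a tensor shape
    return torch.zeros(n)

export(data_dependent_size_example, (torch.randint(2, 16, (4,)),))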

@@ -382,7 +415,9 @@ def constraints_example5(x, y):
# and may change without notice.
#
# Currently, the steps to register a custom op for use by ``torch.export`` are:
#
# - Define the custom op using ``torch.library`` (`reference <https://pytorch.org/docs/main/library.html>`__),
#   as with any other custom op.

from torch.library import Library, impl
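
######################################################################
# A sketch of that definition step (the namespace name and schema below are
# assumptions; the tutorial's own definitions are elided from this excerpt):

m = Library("my_custom_library", "DEF")          # assumed namespace name
m.define("custom_op(Tensor input) -> Tensor")    # assumed operator schema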

@@ -395,7 +430,7 @@ def custom_op(x):
    print("custom_op called!")
    return torch.relu(x)

# - Define a ``"Meta"`` implementation of the custom op that returns an empty
#   tensor with the same shape as the expected output.

@impl(m, "custom_op", "Meta")
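def custom_op_meta(x):
    # sketch of the elided Meta implementation: per the text above, it returns
    # an empty tensor with the same shape as the expected output
    return torch.empty_like(x)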
@@ -444,7 +479,7 @@ def cond_predicate(x):
    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """
    pred = x.dim() > 2 and x.shape[2] > 10
    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

# More generally, ExportDB can be used as a reference when one of the following occurs:
# 1. Before attempting ``torch.export``, you know ahead of time that your model uses some tricky Python/PyTorch features
