"""

######################################################################
- # ``torch.export`` is the PyTorch 2.0 way to export PyTorch models intended
- # to be run on high performance environments.
+ # :func:`torch.export` is the PyTorch 2.0 way to export PyTorch models into
+ # static and standardized model representations, intended
+ # to be run on different (i.e. Python-less) environments.
#
- # ``torch.export`` is built using the components of ``torch.compile``,
- # so it may be helpful to familiarize yourself with ``torch.compile``.
- # For an introduction to ``torch.compile``, see the ` ``torch.compile`` tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__.
- #
- # This tutorial focuses on using ``torch.export`` to extract
+ # In this tutorial, you will learn how to use :func:`torch.export` to extract
# `ExportedProgram`s (i.e. single-graph representations) from PyTorch programs.
+ # We also detail some considerations/modifications that you may need
+ # to make in order to make your model compatible with ``torch.export``.
#
- # **Contents**
- #
- # - Exporting a PyTorch model using ``torch.export``
- # - Comparison to ``torch.compile``
- # - Control Flow Ops
- # - Constraints
- # - Custom Ops
- # - ExportDB
- # - Conclusion
+ # .. contents::
+ #     :local:

######################################################################
# Exporting a PyTorch model using ``torch.export``
# ------------------------------------------------
#
- # ``torch.export`` takes in a callable (including ``torch.nn.Module``s),
+ # ``torch.export`` takes in a callable (including ``torch.nn.Module`` s),
# a tuple of positional arguments, and optionally (not shown in the example below),
- # a dictionary of keyword arguments.
+ # a dictionary of keyword arguments and a list of constraints (covered later).

import torch
from torch.export import export
@@ -51,16 +43,18 @@ def forward(self, x, y):
print(type(exported_mod))

######################################################################
- # ``torch.export`` returns an ``ExportedProgram``, which is not a ``torch.nn.Module``,
- # but can still be ran as a function:
+ # ``torch.export`` returns an ``ExportedProgram``. It is not a ``torch.nn.Module``,
+ # but it can still be run as a function:

print(exported_mod(torch.randn(8, 100), torch.randn(8, 100)))
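
######################################################################
# A quick sanity check (a minimal sketch, not part of the original tutorial)
# of the claim above that an ``ExportedProgram`` is not a ``torch.nn.Module``:

print(isinstance(exported_mod, torch.nn.Module))  # expected: False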

######################################################################
- # ``ExportedProgram`` has some attributes that are of interest.
- # The ``graph`` attribute is an FX graph traced from the function we exported,
- # that is, the computation graph of all PyTorch operations.
+ # Let's review some attributes of ``ExportedProgram`` that are of interest.
+ #
+ # The ``graph`` attribute is an `FX graph <https://pytorch.org/docs/stable/fx.html#torch.fx.Graph>`__
+ # traced from the function we exported, that is, the computation graph of all PyTorch operations.
# The FX graph has some important properties:
+ #
# - The operations are "ATen-level" operations.
# - The graph is "functionalized", meaning that no operations are mutations.
#
@@ -73,12 +67,14 @@ def forward(self, x, y):
exported_mod.graph_module.print_readable()

######################################################################
- # The printed code shows that FX graph only contains ATen-level ops (i.e. ``torch.ops.aten``)
- # and that mutations were removed (e.g. the mutating op ``torch.nn.functional.relu(..., inplace=True)``
- # is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate).
-
- ######################################################################
+ # The printed code shows that FX graph only contains ATen-level ops (such as ``torch.ops.aten``)
+ # and that mutations were removed. For example, the mutating op ``torch.nn.functional.relu(..., inplace=True)``
+ # is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate.
+ # Future uses of input to the original mutating ``relu`` op are replaced by the additional new output
+ # of the replacement non-mutating ``relu`` op.
+ #
# Other attributes of interest in ``ExportedProgram`` include:
+ #
# - ``graph_signature`` -- the inputs, outputs, parameters, buffers, etc. of the exported graph.
# - ``range_constraints`` and ``equality_constraints`` -- Constraints, covered later
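
######################################################################
# As a brief illustration (a minimal sketch, not part of the original
# tutorial), these attributes can be inspected directly on the
# ``exported_mod`` object from above:

print(exported_mod.graph_signature)
print(exported_mod.range_constraints)
print(exported_mod.equality_constraints)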
@@ -87,13 +83,15 @@ def forward(self, x, y):
######################################################################
# Comparison to ``torch.compile``
# -------------------------------
+ #
# Although ``torch.export`` is built on top of the ``torch.compile``
# components, the key limitation of ``torch.export`` is that it does not
# support graph breaks. This is because handling graph breaks involves interpreting
# the unsupported operation with default Python evaluation, which is incompatible
# with the export use case.
#
- # A graph break is necessary in cases such as:
+ # A graph break is necessary in the following cases:
+ #
# - data-dependent control flow

def bad1(x):
@@ -119,9 +117,8 @@ def bad2(x):
except Exception:
    tb.print_exc()

-
######################################################################
- # - calling unsupported functions (e.g. many builtins)
+ # - calling unsupported functions (such as many built-in functions)

def bad3(x):
    x = x + 1
@@ -151,58 +148,75 @@ def bad4(x):
######################################################################
# Control Flow Ops
# ----------------
+ # .. warning::
+ #
+ #     ``cond`` is a prototype feature in PyTorch, included as a part of the ``torch.export`` release.
+ #     Future changes may break backwards compatibility.
+ #     Please look forward to a more stable implementation in a future version of PyTorch.
+ #
# ``torch.export`` actually does support data-dependent control flow.
# But these need to be expressed using control flow ops. For example,
# we can fix the control flow example above using the ``cond`` op, like so:
+ # <!-- TODO link to docs about cond when it is out -->

- from functorch.experimental import control_flow
+ from functorch.experimental.control_flow import cond

def bad1_fixed(x):
    def true_fn(x):
        return torch.sin(x)
    def false_fn(x):
        return torch.cos(x)
-     return control_flow.cond(x.sum() > 0, true_fn, false_fn, [x])
+     return cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(bad1_fixed, (torch.randn(3, 3),))
print(exported_bad1_fixed(torch.ones(3, 3)))
print(exported_bad1_fixed(-torch.ones(3, 3)))

######################################################################
- # There are some limitations one should be aware of:
+ # There are limitations to ``cond`` that one should be aware of:
+ #
# - The predicate (i.e. ``x.sum() > 0``) must result in a boolean or a single-element tensor.
# - The operands (i.e. ``[x]``) must be tensors.
# - The branch function (i.e. ``true_fn`` and ``false_fn``) signature must match with the
- #   operands and they must both return a single tensor with the same metadata (e.g. dtype, shape, etc.)
- # - Branch functions cannot mutate inputs or globals
+ #   operands and they must both return a single tensor with the same metadata (for example, ``dtype``, ``shape``, etc.).
+ # - Branch functions cannot mutate input or global variables.
# - Branch functions cannot access closure variables, except for ``self`` if the function is
#   defined in the scope of a method.

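######################################################################
# As an illustrative sketch (not part of the original tutorial), violating
# one of these limitations -- here, passing a non-tensor operand -- should
# cause the export to fail:

def cond_bad_operands(x):
    def true_fn(y):
        return y + 1.0
    def false_fn(y):
        return y - 1.0
    # the operands list must contain tensors; 1.0 is a plain float
    return cond(x.sum() > 0, true_fn, false_fn, [1.0])

try:
    export(cond_bad_operands, (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
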
+ # <!-- NOTE map is not documented at the moment
+
######################################################################
# We can also use ``map``, which applies a function across the first dimension
# of the first tensor argument.

- from functorch.experimental.control_flow import map
+ # from functorch.experimental.control_flow import map
+
+ # def map_example(xs):
+ #     def map_fn(x, const):
+ #         def true_fn(x):
+ #             return x + const
+ #         def false_fn(x):
+ #             return x - const
+ #         return control_flow.cond(x.sum() > 0, true_fn, false_fn, [x])
+ #     return control_flow.map(map_fn, xs, torch.tensor([2.0]))

- def map_example(xs):
-     def map_fn(x, const):
-         def true_fn(x):
-             return x + const
-         def false_fn(x):
-             return x - const
-         return control_flow.cond(x.sum() > 0, true_fn, false_fn, [x])
-     return control_flow.map(map_fn, xs, torch.tensor([2.0]))

+ # exported_map_example = export(map_example, (torch.randn(4, 3),))
+ # inp = torch.cat((torch.ones(2, 3), -torch.ones(2, 3)))
+ # print(exported_map_example(inp))
- exported_map_example = export(map_example, (torch.randn(4, 3),))
- inp = torch.cat((torch.ones(2, 3), -torch.ones(2, 3)))
- print(exported_map_example(inp))
+ # -->

######################################################################
# Constraints
# -----------
- # Ops can have different specializations for different tensor shapes, so
- # ``ExportedProgram``s uses constraints on tensor shapes in order to ensure
- # correctness with other inputs.
+ # .. warning::
+ #
+ #     The constraints API is a prototype feature in PyTorch, included as a part of the ``torch.export`` release.
+ #     Backwards compatibility is not guaranteed. We anticipate releasing a more stable constraints API in the future.
+ #
+ # Ops can have different specializations/behaviors for different tensor shapes, so by default,
+ # ``torch.export`` requires inputs to ``ExportedProgram`` to have the same shape as the respective
+ # example inputs given to the initial ``torch.export`` call.
# If we try to run the first ``ExportedProgram`` example with a tensor
# with a different shape, we get an error:
@@ -212,15 +226,19 @@ def false_fn(x):
    tb.print_exc()

######################################################################
- # By default, ``torch.export`` requires all tensors to have the same shape
- # as the example inputs, but we can modify the ``torch.export`` call to
+ # We can modify the ``torch.export`` call to
# relax some of these constraints. We use ``torch.export.dynamic_dim`` to
# express shape constraints manually.
#
- # We can use ``dynamic_dim`` to remove a dimension's constraints, or to
- # manually provide an upper or lower bound. In the example below, our input
+ # <!-- TODO link to doc of dynamic_dim when it is available -->
+ # Using ``dynamic_dim`` on a tensor's dimension marks it as dynamic (i.e. unconstrained), and
+ # we can provide additional upper and lower bound shape constraints.
+ # The first argument of ``dynamic_dim`` is the tensor variable we wish
+ # to specify a dimension constraint for. The second argument specifies
+ # the dimension of the first argument the constraint applies to.
+ # In the example below, our input
# ``inp1`` has an unconstrained first dimension, but the size of the second
- # dimension must be in the interval (1, 18].
+ # dimension must be in the interval (3, 18].

from torch.export import dynamic_dim
@@ -250,6 +268,20 @@ def constraints_example1(x):
except Exception:
    tb.print_exc()

+ ######################################################################
+ # Note that if our inputs to ``torch.export`` do not satisfy the constraints,
+ # we get an error.
+
+ constraints1_bad = [
+     dynamic_dim(inp1, 0),
+     10 < dynamic_dim(inp1, 1),
+     dynamic_dim(inp1, 1) <= 18,
+ ]
+ try:
+     export(constraints_example1, (inp1,), constraints=constraints1_bad)
+ except Exception:
+     tb.print_exc()
+
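######################################################################
# For contrast, a minimal sketch (not part of the original tutorial) of a
# well-formed constraint list matching the (3, 18] interval described above;
# this assumes ``inp1``'s second dimension actually lies in that interval:

constraints1_ok = [
    dynamic_dim(inp1, 0),
    3 < dynamic_dim(inp1, 1),
    dynamic_dim(inp1, 1) <= 18,
]
exported_ok = export(constraints_example1, (inp1,), constraints=constraints1_ok)
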
######################################################################
# We can also use ``dynamic_dim`` to enforce expected equalities between
# dimensions, for example, in matrix multiplication:
@@ -343,6 +375,7 @@ def specify_constraints(x, y):
# If a bound is not provided, then it is assumed to be unbounded. ``constrain_as_size``
# is similar to ``constrain_as_value``, except that it should be used on integer values that
# will be used to specify tensor shapes -- in particular, the value must not be 0 or 1 because
+ # many operations have special behavior for tensors with a shape value of 0 or 1.

from torch.export import constrain_as_size, constrain_as_value
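
######################################################################
# A minimal sketch (hypothetical helper, not part of the original tutorial;
# it assumes ``constrain_as_value`` accepts ``min``/``max`` keyword arguments
# as described above): constraining a data-dependent integer lets export
# reason about code that uses it.

def constrain_value_sketch(x, y):
    # ``c`` comes from tensor data, so its value is unknown at export time
    c = x.item()
    # promise the compiler that 0 <= c <= 5
    constrain_as_value(c, min=0, max=5)
    return y[:c]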
@@ -382,7 +415,9 @@ def constraints_example5(x, y):
# and may change without notice.
#
# Currently, the steps to register a custom op for use by ``torch.export`` are:
- # - Define the custom op using ``torch.library`` as with any other custom op
+ #
+ # - Define the custom op using ``torch.library`` (`reference <https://pytorch.org/docs/main/library.html>`__)
+ #   as with any other custom op

from torch.library import Library, impl
@@ -395,7 +430,7 @@ def custom_op(x):
    print("custom_op called!")
    return torch.relu(x)

- # - Define a ``Meta`` implementation of the custom op that returns an empty
+ # - Define a ``"Meta"`` implementation of the custom op that returns an empty
#   tensor with the same shape as the expected output

@impl(m, "custom_op", "Meta")
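def custom_op_meta(x):
    # a sketch of the "Meta" implementation described above (the tutorial's
    # own body is elided in this diff): return an empty tensor with the
    # same shape as the expected output
    return torch.empty_like(x)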
@@ -444,7 +479,7 @@ def cond_predicate(x):
    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """
    pred = x.dim() > 2 and x.shape[2] > 10
-     return control_flow.cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
+     return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

# More generally, ExportDB can be used as a reference when one of the following occurs:
# 1. Before attempting ``torch.export``, you know ahead of time that your model uses some tricky Python/PyTorch features