7
7
"""
8
8
9
9
######################################################################
10
- # :func:`torch.export` is the PyTorch 2.0 way to export PyTorch models into
11
- # static and standardized model representations, intended
10
+ # :func:`torch.export` is the PyTorch 2 way to export PyTorch models into
11
+ # standardized model representations, intended
12
12
# to be run on different (i.e. Python-less) environments.
13
13
#
14
14
# In this tutorial, you will learn how to use :func:`torch.export` to extract
15
- # `ExportedProgram`s (i.e. single-graph representations) from PyTorch programs.
15
+ # `` ExportedProgram``' s (i.e. single-graph representations) from PyTorch programs.
16
16
# We also detail some considerations/modifications that you may need
17
17
# to make in order to make your model compatible with ``torch.export``.
18
18
#
19
+ # **Contents**
19
20
# .. contents::
20
21
# :local:
21
22
22
23
######################################################################
23
- # Exporting a PyTorch model using ``torch.export``
24
- # ------------------------------------------------
24
+ # Basic Usage
25
+ # -----------
26
+ #
27
+ # ``torch.export`` extracts single-graph representations from PyTorch programs
28
+ # by tracing the target function, given example inputs.
29
+ #
30
+ # The signature of ``torch.export`` is:
25
31
#
26
- # ``torch.export`` takes in a callable (including ``torch.nn.Module`` s),
27
- # a tuple of positional arguments, and optionally (not shown in the example below),
28
- # a dictionary of keyword arguments and a list of constraints (covered later).
32
+ # .. code:: python
33
+ #
34
+ # export(
35
+ # f: Callable,
36
+ # args: Tuple[Any, ...],
37
+ # kwargs: Optional[Dict[str, Any]] = None,
38
+ # *,
39
+ # constraints: Optional[List[Constraint]] = None
40
+ # ) -> ExportedProgram
41
+ #
42
+ # ``torch.export`` traces the tensor computation graph from calling ``f(*args, **kwargs)``
43
+ # and wraps it in an ``ExportedProgram``, which can be serialized or executed later with
44
+ # different inputs. Note that while the output ``ExportedGraph`` is callable, it is not a
45
+ # ``torch.nn.Module``. We will detail the ``constraints`` argument later in the tutorial.
29
46
30
47
import torch
31
48
from torch .export import export
@@ -41,11 +58,6 @@ def forward(self, x, y):
41
58
mod = MyModule ()
42
59
exported_mod = export (mod , (torch .randn (8 , 100 ), torch .randn (8 , 100 )))
43
60
print (type (exported_mod ))
44
-
45
- ######################################################################
46
- # ``torch.export`` returns an ``ExportedProgram``. It is not a ``torch.nn.Module``,
47
- # but it can still be run as a function:
48
-
49
61
print (exported_mod (torch .randn (8 , 100 ), torch .randn (8 , 100 )))
50
62
51
63
######################################################################
@@ -79,18 +91,25 @@ def forward(self, x, y):
79
91
# - ``range_constraints`` and ``equality_constraints`` -- Constraints, covered later
80
92
81
93
print (exported_mod .graph_signature )
94
+ print (exported_mod .range_constraints )
95
+ print (exported_mod .equality_constraints )
96
+
97
+ ######################################################################
98
+ # See the ``torch.export`` `documentation <https://pytorch.org/docs/main/export.html#torch.export.export>`__
99
+ # for more details.
82
100
83
101
######################################################################
84
- # Comparison to ``torch.compile``
85
- # -------------------------------
102
+ # Graph Breaks
103
+ # ------------
86
104
#
87
- # Although ``torch.export`` is built on top of the ``torch.compile``
88
- # components, the key limitation of ``torch.export`` is that it does not
105
+ # Although ``torch.export`` shares components with ``torch.compile``,
106
+ # the key limitation of ``torch.export``, especially when compared to ``torch.compile``, is that it does not
89
107
# support graph breaks. This is because handling graph breaks involves interpreting
90
108
# the unsupported operation with default Python evaluation, which is incompatible
91
- # with the export use case.
109
+ # with the export use case. Therefore, in order to make your model code compatible
110
+ # with ``torch.export``, you will need to modify your code to remove graph breaks.
92
111
#
93
- # A graph break is necessary in the following cases :
112
+ # A graph break is necessary in cases such as :
94
113
#
95
114
# - data-dependent control flow
96
115
@@ -145,9 +164,14 @@ def bad4(x):
145
164
except Exception :
146
165
tb .print_exc ()
147
166
167
+ ######################################################################
168
+ # The sections below demonstrate some ways you can modify your code
169
+ # in order to remove graph breaks.
170
+
148
171
######################################################################
149
172
# Control Flow Ops
150
173
# ----------------
174
+ #
151
175
# .. warning::
152
176
#
153
177
# ``cond`` is a prototype feature in PyTorch, included as a part of the ``torch.export`` release.
@@ -157,7 +181,9 @@ def bad4(x):
157
181
# ``torch.export`` actually does support data-dependent control flow.
158
182
# But these need to be expressed using control flow ops. For example,
159
183
# we can fix the control flow example above using the ``cond`` op, like so:
160
- # <!-- TODO link to docs about cond when it is out -->
184
+
185
+ # ..
186
+ # [TODO] link to docs about cond when it is out
161
187
162
188
from functorch .experimental .control_flow import cond
163
189
@@ -178,37 +204,36 @@ def false_fn(x):
178
204
# - The predicate (i.e. ``x.sum() > 0``) must result in a boolean or a single-element tensor.
179
205
# - The operands (i.e. ``[x]``) must be tensors.
180
206
# - The branch function (i.e. ``true_fn`` and ``false_fn``) signature must match with the
181
- # operands and they must both return a single tensor with the same metadata (for example, ``dtype``, ``shape``, etc.).
207
+ # operands and they must both return a single tensor with the same metadata (for example, ``dtype``, ``shape``, etc.).
182
208
# - Branch functions cannot mutate input or global variables.
183
209
# - Branch functions cannot access closure variables, except for ``self`` if the function is
184
210
# defined in the scope of a method.
185
211
186
- # <!-- NOTE map is not documented at the moment
187
-
188
212
######################################################################
189
- # We can also use ``map``, which applies a function across the first dimension
190
- # of the first tensor argument.
191
-
192
- # from functorch.experimental.control_flow import map
193
-
194
- # def map_example(xs):
195
- # def map_fn(x, const):
196
- # def true_fn(x ):
197
- # return x + const
198
- # def false_fn (x):
199
- # return x - const
200
- # return control_flow.cond(x.sum() > 0, true_fn, false_fn, [x])
201
- # return control_flow.map(map_fn, xs, torch.tensor([2.0]))
202
-
203
- # exported_map_example= export(map_example, (torch.randn(4, 3), ))
204
- # inp = torch.cat((torch.ones(2, 3), -torch.ones(2, 3)))
205
- # print( exported_map_example(inp ))
206
-
207
- # -->
213
+ # ..
214
+ # [NOTE] map is not documented at the moment
215
+ # We can also use ``map``, which applies a function across the first dimension
216
+ # of the first tensor argument.
217
+ #
218
+ # from functorch.experimental.control_flow import map
219
+ #
220
+ # def map_example(xs ):
221
+ # def map_fn(x, const):
222
+ # def true_fn (x):
223
+ # return x + const
224
+ # def false_fn(x):
225
+ # return x - const
226
+ # return control_flow.cond(x.sum() > 0, true_fn, false_fn, [x])
227
+ # return control_flow.map(map_fn, xs, torch.tensor([2.0] ))
228
+ #
229
+ # exported_map_example= export(map_example, (torch.randn(4, 3), ))
230
+ # inp = torch.cat((torch.ones(2, 3), -torch.ones(2, 3)))
231
+ # print(exported_map_example(inp))
208
232
209
233
######################################################################
210
234
# Constraints
211
235
# -----------
236
+ #
212
237
# .. warning::
213
238
#
214
239
# The constraints API is a prototype feature in PyTorch, included as a part of the torch.export release.
@@ -230,7 +255,9 @@ def false_fn(x):
230
255
# relax some of these constraints. We use ``torch.export.dynamic_dim`` to
231
256
# express shape constraints manually.
232
257
#
233
- # <!-- TODO link to doc of dynamic_dim when it is available -->
258
+ # ..
259
+ # [TODO] link to doc of dynamic_dim when it is available
260
+ #
234
261
# Using ``dynamic_dim`` on a tensor's dimension marks it as dynamic (i.e. unconstrained), and
235
262
# we can provide additional upper and lower bound shape constraints.
236
263
# The first argument of ``dynamic_dim`` is the tensor variable we wish
@@ -269,8 +296,8 @@ def constraints_example1(x):
269
296
tb .print_exc ()
270
297
271
298
######################################################################
272
- # Note that if our inputs to ``torch.export`` do not satisfy the constraints,
273
- # we get an error.
299
+ # Note that if our example inputs to ``torch.export`` do not satisfy the constraints,
300
+ # then we get an error.
274
301
275
302
constraints1_bad = [
276
303
dynamic_dim (inp1 , 0 ),
@@ -309,7 +336,9 @@ def constraints_example2(x, y):
309
336
310
337
######################################################################
311
338
# We can actually use ``torch.export`` to guide us as to which constraints
312
- # are necessary. We can do this by relaxing all constraints and letting ``torch.export``
339
+ # are necessary. We can do this by relaxing all constraints (recall that if we
340
+ # do not provide constraints for a dimension, the default behavior is to constrain
341
+ # to the exact shape value of the example input) and letting ``torch.export``
313
342
# error out.
314
343
315
344
inp4 = torch .randn (8 , 16 )
@@ -372,10 +401,7 @@ def specify_constraints(x, y):
372
401
# We can also constrain on individual values in the source code itself using
373
402
# ``constrain_as_value`` and ``constrain_as_size``. ``constrain_as_value`` specifies
374
403
# that a given integer value is expected to fall within the provided minimum/maximum bounds (inclusive).
375
- # If a bound is not provided, then it is assumed to be unbounded. ``constrain_as_size``
376
- # is similar to ``constrain_as_value``, except that it should be used on integer values that
377
- # will be used to specify tensor shapes -- in particular, the value must not be 0 or 1 because
378
- # many operations have special behavior for tensors with a shape value of 0 or 1.
404
+ # If a bound is not provided, then it is assumed to be unbounded.
379
405
380
406
from torch .export import constrain_as_size , constrain_as_value
381
407
@@ -393,6 +419,11 @@ def constraints_example4(x, y):
393
419
except Exception :
394
420
tb .print_exc ()
395
421
422
+ ######################################################################
423
+ # ``constrain_as_size`` is similar to ``constrain_as_value``, except that it should be used on integer values that
424
+ # will be used to specify tensor shapes -- in particular, the value must not be 0 or 1 because
425
+ # many operations have special behavior for tensors with a shape value of 0 or 1.
426
+
396
427
def constraints_example5 (x , y ):
397
428
b = y .item ()
398
429
constrain_as_size (b )
@@ -409,15 +440,18 @@ def constraints_example5(x, y):
409
440
######################################################################
410
441
# Custom Ops
411
442
# ----------
443
+ #
412
444
# ``torch.export`` can export PyTorch programs with custom operators.
413
445
#
414
- # NOTE: the API for registering custom ops is still under active development
415
- # and may change without notice.
446
+ # .. warning::
447
+ #
448
+ # The API for registering custom ops is still under active development
449
+ # and may change without notice.
416
450
#
417
451
# Currently, the steps to register a custom op for use by ``torch.export`` are:
418
452
#
419
453
# - Define the custom op using ``torch.library`` (`reference <https://pytorch.org/docs/main/library.html>`__)
420
- # as with any other custom op
454
+ # as with any other custom op
421
455
422
456
from torch .library import Library , impl
423
457
@@ -430,13 +464,15 @@ def custom_op(x):
430
464
print ("custom_op called!" )
431
465
return torch .relu (x )
432
466
467
+ ######################################################################
433
468
# - Define a ``"Meta"`` implementation of the custom op that returns an empty
434
- # tensor with the same shape as the expected output
469
+ # tensor with the same shape as the expected output
435
470
436
471
@impl (m , "custom_op" , "Meta" )
437
472
def custom_op_meta (x ):
438
473
return torch .empty_like (x )
439
474
475
+ ######################################################################
440
476
# - Call the custom op from the code you want to export using ``torch.ops``
441
477
442
478
def custom_op_example (x ):
@@ -445,24 +481,27 @@ def custom_op_example(x):
445
481
x = torch .cos (x )
446
482
return x
447
483
484
+ ######################################################################
448
485
# - Export the code as before
449
486
450
487
exported_custom_op_example = export (custom_op_example , (torch .randn (3 , 3 ),))
451
488
exported_custom_op_example .graph_module .print_readable ()
452
489
print (exported_custom_op_example (torch .randn (3 , 3 )))
453
490
491
+ ######################################################################
454
492
# Note in the above outputs that the custom op is included in the exported graph.
455
493
# And when we call the exported graph as a function, the original custom op is called,
456
494
# as evidenced by the ``print`` call.
457
495
458
496
######################################################################
459
497
# ExportDB
460
498
# --------
499
+ #
461
500
# ``torch.export`` will only ever export a single computation graph from a PyTorch program. Because of this requirement,
462
501
# there will be Python or PyTorch features that are not compatible with ``torch.export``, which will require users to
463
502
# rewrite parts of their model code. We have seen examples of this earlier in the tutorial -- for example, rewriting
464
503
# if-statements using ``cond``.
465
-
504
+ #
466
505
# `ExportDB <https://pytorch.org/docs/main/generated/exportdb/index.html>`__ is the standard reference that documents
467
506
# supported and unsupported Python/PyTorch features for ``torch.export``. It is essentially a list a program samples, each
468
507
# of which represents the usage of one particular Python/PyTorch feature and its interaction with ``torch.export``.
@@ -481,17 +520,20 @@ def cond_predicate(x):
481
520
pred = x .dim () > 2 and x .shape [2 ] > 10
482
521
return cond (pred , lambda x : x .cos (), lambda y : y .sin (), [x ])
483
522
523
+ ######################################################################
484
524
# More generally, ExportDB can be used as a reference when one of the following occurs:
525
+ #
485
526
# 1. Before attempting ``torch.export``, you know ahead of time that your model uses some tricky Python/PyTorch features
486
527
# and you want to know if ``torch.export`` covers that feature.
487
528
# 2. When attempting ``torch.export``, there is a failure and it's unclear how to work around it.
488
-
529
+ #
489
530
# ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach
490
531
# out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by ``torch.export``.
491
532
492
533
######################################################################
493
534
# Conclusion
494
535
# ----------
495
- # We introduced ``torch.export``, the new PyTorch 2.0 way to export single computation
536
+ #
537
+ # We introduced ``torch.export``, the new PyTorch 2 way to export single computation
496
538
# graphs from PyTorch programs. In particular, we demonstrate several code modifications
497
539
# and considerations (control flow ops, constraints, etc.) that need to be made in order to export a graph.
0 commit comments