
Commit a290b2c

update torch.export nightly tutorial with decomp section and other small fixes
1 parent: b5f19fe

2 files changed (+238, -58 lines)

intermediate_source/_torch_export_nightly_tutorial.py

Lines changed: 98 additions & 48 deletions
@@ -3,7 +3,7 @@
 """
 torch.export Nightly Tutorial
 ================
-**Author:** William Wen, Zhengxu Chen
+**Author:** William Wen, Zhengxu Chen, Angela Yi
 """

 ######################################################################
@@ -184,9 +184,6 @@ def bad4(x):
 # ``torch.export`` actually does support data-dependent control flow.
 # But these need to be expressed using control flow ops. For example,
 # we can fix the control flow example above using the ``cond`` op, like so:
-#
-# ..
-#   [TODO] link to docs about ``cond`` when it is out

 from functorch.experimental.control_flow import cond

@@ -211,6 +208,8 @@ def false_fn(x):
 # - Branch functions cannot mutate input or global variables.
 # - Branch functions cannot access closure variables, except for ``self`` if the function is
 #   defined in the scope of a method.
+#
+# For more details about ``cond``, check out the `documentation <https://pytorch.org/docs/main/cond.html>`__.

 ######################################################################
 # ..
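For context, here is a minimal sketch of a ``cond`` call that satisfies the constraints listed above; the predicate and branch functions are illustrative and not part of the tutorial's own example:

.. code-block:: python

    import torch
    from torch.export import export
    from functorch.experimental.control_flow import cond

    def true_fn(x):
        # Branches receive their operands explicitly and must not mutate them
        # or capture closure variables.
        return x.sin()

    def false_fn(x):
        return x.cos()

    def f(x):
        # The predicate is a boolean scalar tensor; operands are passed as a list.
        return cond(x.sum() > 0, true_fn, false_fn, [x])

    exported_cond = export(f, (torch.randn(3, 3),))
    print(exported_cond(torch.randn(3, 3)))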
@@ -261,12 +260,10 @@ def forward(self, x, y):

 ######################################################################
 # We can relax this constraint using the ``dynamic_shapes`` argument of
-# ``torch.export.export()``, which allows us to specify (using ``torch.export.Dim``)
+# ``torch.export.export()``, which allows us to specify, using ``torch.export.Dim``
+# (`documentation <https://pytorch.org/docs/main/export.html#torch.export.Dim>`__),
 # which dimensions of the input tensors are dynamic.
 #
-# ..
-#   [TODO] link to doc of Dim when it is available
-#
 # For each tensor argument of the input callable, we can specify a mapping from the dimension
 # to a ``torch.export.Dim``.
 # A ``torch.export.Dim`` is essentially a named symbolic integer with optional
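For context, a minimal sketch of the ``dynamic_shapes`` argument described above; the function and dimension names are illustrative:

.. code-block:: python

    import torch
    from torch.export import export, Dim

    def add_pairs(x, y):
        return x + y

    # A named symbolic dimension; min/max bounds are optional.
    batch = Dim("batch", min=2, max=1024)

    # Map dimension 0 of each tensor argument to the same Dim, marking it
    # dynamic and constrained to be equal across both inputs.
    exported_add = export(
        add_pairs,
        (torch.randn(4, 3), torch.randn(4, 3)),
        dynamic_shapes={"x": {0: batch}, "y": {0: batch}},
    )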
@@ -429,46 +426,6 @@ def suggested_fixes():
 print(exported_dynamic_shapes_example3.range_constraints)
 print(exported_dynamic_shapes_example3.equality_constraints)

-######################################################################
-# We can also constrain on individual values in the source code itself using
-# ``constrain_as_value`` and ``constrain_as_size``. ``constrain_as_value`` specifies
-# that a given integer value is expected to fall within the provided minimum/maximum bounds (inclusive).
-# If a bound is not provided, then it is assumed to be unbounded.
-
-from torch.export import constrain_as_size, constrain_as_value
-
-def dynamic_shapes_example4(x, y):
-    b = y.item()
-    constrain_as_value(b, 3, 5)
-    if b >= 3:
-        return x.cos()
-    return x.sin()
-
-exported_dynamic_shapes_example4 = export(dynamic_shapes_example4, (torch.randn(3, 3), torch.tensor([4])))
-print(exported_dynamic_shapes_example4(torch.randn(3, 3), torch.tensor([5])))
-try:
-    exported_dynamic_shapes_example4(torch.randn(3, 3), torch.tensor([2]))
-except Exception:
-    tb.print_exc()
-
-######################################################################
-# ``constrain_as_size`` is similar to ``constrain_as_value``, except that it should be used on integer values that
-# will be used to specify tensor shapes -- in particular, the value must not be 0 or 1 because
-# many operations have special behavior for tensors with a shape value of 0 or 1.
-
-def dynamic_shapes_example5(x, y):
-    b = y.item()
-    constrain_as_size(b)
-    z = torch.ones(b, 4)
-    return x.sum() + z.sum()
-
-exported_dynamic_shapes_example5 = export(dynamic_shapes_example5, (torch.randn(2, 2), torch.tensor([4])))
-print(exported_dynamic_shapes_example5(torch.randn(2, 2), torch.tensor([5])))
-try:
-    exported_dynamic_shapes_example5(torch.randn(2, 2), torch.tensor([1]))
-except Exception:
-    tb.print_exc()
-
 ######################################################################
 # Custom Ops
 # ----------
@@ -520,6 +477,99 @@ def custom_op_example(x):
 # Note in the above outputs that the custom op is included in the exported graph.
 # And when we call the exported graph as a function, the original custom op is called,
 # as evidenced by the ``print`` call.
+#
+# If you have a custom operator implemented in C++, please refer to
+# `this document <https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz>`__
+# to make it compatible with ``torch.export``.
+
+######################################################################
+# Decompositions
+# --------------
+#
+# The graph produced by ``torch.export`` by default returns a graph containing
+# only functional ATen operators. This functional ATen operator set (or "opset") contains around 2000
+# operators, all of which are functional, that is, they do not
+# mutate or alias inputs. You can find a list of all ATen operators
+# `here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml>`__
+# and you can inspect if an operator is functional by checking
+# ``op._schema.is_mutable``, for example:
+
+print(torch.ops.aten.add.Tensor._schema.is_mutable)
+print(torch.ops.aten.add_.Tensor._schema.is_mutable)
+
+######################################################################
+# By default, the environment in which you want to run the exported graph
+# should support all ~2000 of these operators.
+# However, you can use the following API on the exported program
+# if your specific environment is only able to support a subset of
+# the ~2000 operators.
+#
+# .. code:: python
+#
+#     def run_decompositions(
+#         self: ExportedProgram,
+#         decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]]
+#     ) -> ExportedProgram
+#
+# ``run_decompositions`` takes in a decomposition table, which is a mapping of
+# operators to a function specifying how to reduce, or decompose, that operator
+# into an equivalent sequence of other ATen operators.
+#
+# The default decomposition table for ``run_decompositions`` is the
+# `Core ATen decomposition table <https://github.com/pytorch/pytorch/blob/b460c3089367f3fadd40aa2cb3808ee370aa61e1/torch/_decomp/__init__.py#L252>`__
+# which will decompose all ATen operators to the
+# `Core ATen Operator Set <https://pytorch.org/docs/main/torch.compiler_ir.html#core-aten-ir>`__
+# which consists of only ~180 operators.
+
+class M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(3, 4)
+
+    def forward(self, x):
+        return self.linear(x)
+
+ep = export(M(), (torch.randn(2, 3),))
+print(ep.graph)
+
+core_ir_ep = ep.run_decompositions()
+print(core_ir_ep.graph)
+
+######################################################################
+# Notice that after running ``run_decompositions`` the
+# ``torch.ops.aten.t.default`` operator, which is not part of the Core ATen
+# Opset, has been replaced with ``torch.ops.aten.permute.default`` which is part
+# of the Core ATen Opset.
+
+######################################################################
+# Most ATen operators already have decompositions, which are located
+# `here <https://github.com/pytorch/pytorch/blob/b460c3089367f3fadd40aa2cb3808ee370aa61e1/torch/_decomp/decompositions.py>`__.
+# If you would like to use some of these existing decomposition functions,
+# you can pass in a list of operators you would like to decompose to the
+# :func:`get_decompositions <https://github.com/pytorch/pytorch/blob/b460c3089367f3fadd40aa2cb3808ee370aa61e1/torch/_decomp/__init__.py#L191>`__
+# function, which will return a decomposition table using the pre-implemented
+# decompositions.
+
+class M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(3, 4)
+
+    def forward(self, x):
+        return self.linear(x)
+
+ep = export(M(), (torch.randn(2, 3),))
+print(ep.graph)
+
+from torch._decomp import get_decompositions
+decomp_table = get_decompositions([torch.ops.aten.t.default, torch.ops.aten.transpose.int])
+core_ir_ep = ep.run_decompositions(decomp_table)
+print(core_ir_ep.graph)
+
+######################################################################
+# If there is no existing decomposition function for an ATen operator that you would
+# like to decompose, feel free to send a pull request into PyTorch
+# implementing the decomposition!

 ######################################################################
 # ExportDB
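For context, a minimal sketch of what a hand-written decomposition table entry can look like; the ``t_decomp`` helper and variable names are illustrative, following the ``run_decompositions`` API shown above:

.. code-block:: python

    import torch
    from torch.export import export

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 4)

        def forward(self, x):
            return self.linear(x)

    ep = export(M(), (torch.randn(2, 3),))

    # Illustrative hand-written decomposition: rewrite aten.t as a permute.
    def t_decomp(x):
        return torch.ops.aten.permute.default(x, [1, 0])

    custom_table = {torch.ops.aten.t.default: t_decomp}
    decomposed_ep = ep.run_decompositions(custom_table)
    print(decomposed_ep.graph)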

intermediate_source/torch_export_nightly_tutorial.rst

Lines changed: 140 additions & 10 deletions
@@ -116,12 +116,12 @@ so that it can be ran as a ``torch.nn.Module``.
 def forward(self, arg0_1: f32[10, 100], arg1_1: f32[10], arg2_1: f32[8, 100], arg3_1: f32[8, 100]):
     # File: torch_export_nightly_tutorial.py:69, code: return torch.nn.functional.relu(self.lin(x + y), inplace=True)
     add: f32[8, 100] = torch.ops.aten.add.Tensor(arg2_1, arg3_1); arg2_1 = arg3_1 = None
-    permute: f32[100, 10] = torch.ops.aten.permute.default(arg0_1, [1, 0]); arg0_1 = None
-    addmm: f32[8, 10] = torch.ops.aten.addmm.default(arg1_1, add, permute); arg1_1 = add = permute = None
+    t: f32[100, 10] = torch.ops.aten.t.default(arg0_1); arg0_1 = None
+    addmm: f32[8, 10] = torch.ops.aten.addmm.default(arg1_1, add, t); arg1_1 = add = t = None
     relu: f32[8, 10] = torch.ops.aten.relu.default(addmm); addmm = None
     return (relu,)

-Graph signature: ExportGraphSignature(parameters=['lin.weight', 'lin.bias'], buffers=[], user_inputs=['arg2_1', 'arg3_1'], user_outputs=['relu'], inputs_to_parameters={'arg0_1': 'lin.weight', 'arg1_1': 'lin.bias'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
+Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='lin.weight'), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg1_1'), target='lin.bias'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg2_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='relu'), target=None)])
 Range constraints: {}
 Equality constraints: []

@@ -131,13 +131,11 @@ so that it can be ran as a ``torch.nn.Module``.

 def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
     add = torch.ops.aten.add.Tensor(arg2_1, arg3_1); arg2_1 = arg3_1 = None
-    permute = torch.ops.aten.permute.default(arg0_1, [1, 0]); arg0_1 = None
-    addmm = torch.ops.aten.addmm.default(arg1_1, add, permute); arg1_1 = add = permute = None
+    t = torch.ops.aten.t.default(arg0_1); arg0_1 = None
+    addmm = torch.ops.aten.addmm.default(arg1_1, add, t); arg1_1 = add = t = None
     relu = torch.ops.aten.relu.default(addmm); addmm = None
     return (relu,)

-# To see more debug info, please use `graph_module.print_readable()`
-
 The printed code shows that FX graph only contains ATen-level ops (such as ``torch.ops.aten``)
 and that mutations were removed. For example, the mutating op ``torch.nn.functional.relu(..., inplace=True)``
 is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate.
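For context, the ``ExportGraphSignature`` structure shown in the new output above can also be inspected programmatically; a minimal sketch, with an illustrative module mirroring the one printed above:

.. code-block:: python

    import torch
    from torch.export import export

    class Lin(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(100, 10)

        def forward(self, x, y):
            return torch.nn.functional.relu(self.lin(x + y), inplace=True)

    ep = export(Lin(), (torch.randn(8, 100), torch.randn(8, 100)))

    # Each InputSpec records whether a placeholder is a parameter, buffer, or user input.
    for spec in ep.graph_signature.input_specs:
        print(spec.kind, spec.arg.name, spec.target)
    for spec in ep.graph_signature.output_specs:
        print(spec.kind, spec.arg.name)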
@@ -269,9 +267,6 @@ Control Flow Ops
 But these need to be expressed using control flow ops. For example,
 we can fix the control flow example above using the ``cond`` op, like so:

-..
-   [TODO] link to docs about ``cond`` when it is out
-
 .. code-block:: python

     from functorch.experimental.control_flow import cond
@@ -306,6 +301,8 @@ There are limitations to ``cond`` that one should be aware of:
 - Branch functions cannot access closure variables, except for ``self`` if the function is
   defined in the scope of a method.

+For more details about ``cond``, check out the `documentation <https://pytorch.org/docs/main/cond.html>`__.
+
 ..
   [NOTE] map is not documented at the moment
 We can also use ``map``, which applies a function across the first dimension
@@ -683,6 +680,139 @@ Note in the above outputs that the custom op is included in the exported graph.
 And when we call the exported graph as a function, the original custom op is called,
 as evidenced by the ``print`` call.

+If you have a custom operator implemented in C++, please refer to
+`this document <https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz>`__
+to make it compatible with ``torch.export``.
+
+Decompositions
+--------------
+
+The graph produced by ``torch.export`` by default returns a graph containing
+only functional ATen operators. This functional ATen operator set (or "opset") contains around 2000
+operators, all of which are functional, that is, they do not
+mutate or alias inputs. You can find a list of all ATen operators
+`here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml>`__
+and you can inspect if an operator is functional by checking
+``op._schema.is_mutable``, for example:
+
+.. code-block:: python
+
+    print(torch.ops.aten.add.Tensor._schema.is_mutable)
+    print(torch.ops.aten.add_.Tensor._schema.is_mutable)
+
+.. code-block:: bash
+
+    False
+    True
+
+By default, the environment in which you want to run the exported graph
+should support all ~2000 of these operators.
+However, you can use the following API on the exported program
+if your specific environment is only able to support a subset of
+the ~2000 operators.
+
+.. code-block:: python
+
+    def run_decompositions(
+        self: ExportedProgram,
+        decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]]
+    ) -> ExportedProgram
+
+``run_decompositions`` takes in a decomposition table, which is a mapping of
+operators to a function specifying how to reduce, or decompose, that operator
+into an equivalent sequence of other ATen operators.
+
+The default decomposition table for ``run_decompositions`` is the
+`Core ATen decomposition table <https://github.com/pytorch/pytorch/blob/b460c3089367f3fadd40aa2cb3808ee370aa61e1/torch/_decomp/__init__.py#L252>`__
+which will decompose all ATen operators to the
+`Core ATen Operator Set <https://pytorch.org/docs/main/torch.compiler_ir.html#core-aten-ir>`__
+which consists of only ~180 operators.
+
+.. code-block:: python
+
+    class M(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(3, 4)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    ep = export(M(), (torch.randn(2, 3),))
+    print(ep.graph)
+
+    core_ir_ep = ep.run_decompositions()
+    print(core_ir_ep.graph)
+
+.. code-block:: bash
+
+    graph():
+        %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
+        %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
+        %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
+        %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%arg0_1,), kwargs = {})
+        %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %t), kwargs = {})
+        return (addmm,)
+    graph():
+        %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
+        %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
+        %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
+        %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
+        %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %permute), kwargs = {})
+        return (addmm,)
+
+Notice that after running ``run_decompositions`` the
+``torch.ops.aten.t.default`` operator, which is not part of the Core ATen
+Opset, has been replaced with ``torch.ops.aten.permute.default`` which is part
+of the Core ATen Opset.
+
+Most ATen operators already have decompositions, which are located
+`here <https://github.com/pytorch/pytorch/blob/b460c3089367f3fadd40aa2cb3808ee370aa61e1/torch/_decomp/decompositions.py>`__.
+If you would like to use some of these existing decomposition functions,
+you can pass in a list of operators you would like to decompose to the
+:func:`get_decompositions <https://github.com/pytorch/pytorch/blob/b460c3089367f3fadd40aa2cb3808ee370aa61e1/torch/_decomp/__init__.py#L191>`__
+function, which will return a decomposition table using the pre-implemented
+decompositions.
+
+.. code-block:: python
+
+    class M(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(3, 4)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    ep = export(M(), (torch.randn(2, 3),))
+    print(ep.graph)
+
+    from torch._decomp import get_decompositions
+    decomp_table = get_decompositions([torch.ops.aten.t.default, torch.ops.aten.transpose.int])
+    core_ir_ep = ep.run_decompositions(decomp_table)
+    print(core_ir_ep.graph)
+
+.. code-block:: bash
+
+    graph():
+        %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
+        %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
+        %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
+        %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%arg0_1,), kwargs = {})
+        %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %t), kwargs = {})
+        return (addmm,)
+    graph():
+        %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
+        %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
+        %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
+        %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
+        %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %permute), kwargs = {})
+        return (addmm,)
+
+If there is no existing decomposition function for an ATen operator that you would
+like to decompose, feel free to send a pull request into PyTorch
+implementing the decomposition!
+
 ExportDB
 --------
