@@ -116,12 +116,12 @@ so that it can be run as a ``torch.nn.Module``.
    def forward(self, arg0_1: f32[10, 100], arg1_1: f32[10], arg2_1: f32[8, 100], arg3_1: f32[8, 100]):
        # File: torch_export_nightly_tutorial.py:69, code: return torch.nn.functional.relu(self.lin(x + y), inplace=True)
        add: f32[8, 100] = torch.ops.aten.add.Tensor(arg2_1, arg3_1);  arg2_1 = arg3_1 = None
-       permute: f32[100, 10] = torch.ops.aten.permute.default(arg0_1, [1, 0]);  arg0_1 = None
-       addmm: f32[8, 10] = torch.ops.aten.addmm.default(arg1_1, add, permute);  arg1_1 = add = permute = None
+       t: f32[100, 10] = torch.ops.aten.t.default(arg0_1);  arg0_1 = None
+       addmm: f32[8, 10] = torch.ops.aten.addmm.default(arg1_1, add, t);  arg1_1 = add = t = None
        relu: f32[8, 10] = torch.ops.aten.relu.default(addmm);  addmm = None
        return (relu,)

-Graph signature: ExportGraphSignature(parameters=['lin.weight', 'lin.bias'], buffers=[], user_inputs=['arg2_1', 'arg3_1'], user_outputs=['relu'], inputs_to_parameters={'arg0_1': 'lin.weight', 'arg1_1': 'lin.bias'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
+Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='lin.weight'), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg1_1'), target='lin.bias'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg2_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='relu'), target=None)])
 Range constraints: {}
 Equality constraints: []
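The signature shown above can also be read programmatically. As a quick sketch (``exported_mod`` is a placeholder name we use here for the ``ExportedProgram`` printed above; the field names follow the ``ExportGraphSignature`` repr shown in this output):

.. code-block:: python

    # List which placeholders are parameters vs. user inputs.
    for spec in exported_mod.graph_signature.input_specs:
        print(spec.kind, spec.arg.name, spec.target)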
@@ -131,13 +131,11 @@ so that it can be run as a ``torch.nn.Module``.
    def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
        add = torch.ops.aten.add.Tensor(arg2_1, arg3_1);  arg2_1 = arg3_1 = None
-       permute = torch.ops.aten.permute.default(arg0_1, [1, 0]);  arg0_1 = None
-       addmm = torch.ops.aten.addmm.default(arg1_1, add, permute);  arg1_1 = add = permute = None
+       t = torch.ops.aten.t.default(arg0_1);  arg0_1 = None
+       addmm = torch.ops.aten.addmm.default(arg1_1, add, t);  arg1_1 = add = t = None
        relu = torch.ops.aten.relu.default(addmm);  addmm = None
        return (relu,)

-    # To see more debug info, please use `graph_module.print_readable()`
-
The printed code shows that the FX graph contains only ATen-level ops (such as ``torch.ops.aten``)
and that mutations have been removed. For example, the mutating op ``torch.nn.functional.relu(..., inplace=True)``
is represented in the printed code by ``torch.ops.aten.relu.default``, which does not mutate.
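One way to confirm this is to query the operator schemas directly; a small sketch we add here for illustration, comparing the functional and in-place relu overloads:

.. code-block:: python

    import torch

    # The functional variant does not mutate its input...
    print(torch.ops.aten.relu.default._schema.is_mutable)   # False
    # ...while the in-place variant does.
    print(torch.ops.aten.relu_.default._schema.is_mutable)  # True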
@@ -269,9 +267,6 @@ Control Flow Ops
But these need to be expressed using control flow ops. For example,
we can fix the control flow example above using the ``cond`` op, like so:

-..
-    [TODO] link to docs about ``cond`` when it is out
-
.. code-block:: python

    from functorch.experimental.control_flow import cond
@@ -306,6 +301,8 @@ There are limitations to ``cond`` that one should be aware of:
- Branch functions cannot access closure variables, except for ``self`` if the function is
  defined in the scope of a method.

+For more details about ``cond``, check out the `documentation <https://pytorch.org/docs/main/cond.html>`__.
+
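As a minimal illustration of the closure-variable restriction above (a sketch added here, not part of the original tutorial), the branch functions below receive ``x`` through the ``operands`` list instead of closing over it:

.. code-block:: python

    import torch
    from functorch.experimental.control_flow import cond

    class BranchExample(torch.nn.Module):
        def forward(self, x):
            def true_fn(x):
                return torch.sin(x)

            def false_fn(x):
                return torch.cos(x)

            # ``x`` is passed explicitly as an operand; the branches may not
            # reference it as a closure variable.
            return cond(x.sum() > 0, true_fn, false_fn, [x])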
..
    [NOTE] map is not documented at the moment
    We can also use ``map``, which applies a function across the first dimension
@@ -683,6 +680,139 @@ Note in the above outputs that the custom op is included in the exported graph.
And when we call the exported graph as a function, the original custom op is called,
as evidenced by the ``print`` call.

+If you have a custom operator implemented in C++, please refer to
+`this document <https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz>`__
+to make it compatible with ``torch.export``.
+
+Decompositions
+--------------
+
+By default, ``torch.export`` produces a graph containing only functional ATen
+operators. This functional ATen operator set (or "opset") contains around 2000
+operators, none of which mutate or alias their inputs. You can find a list of all
+ATen operators
+`here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml>`__
+and you can check whether an operator is functional by inspecting
+``op._schema.is_mutable``, for example:
+
+.. code-block:: python
+
+    print(torch.ops.aten.add.Tensor._schema.is_mutable)
+    print(torch.ops.aten.add_.Tensor._schema.is_mutable)
+
+.. code-block:: bash
+
+    False
+    True
+
+By default, the environment in which you want to run the exported graph
+should support all ~2000 of these operators.
+However, if your environment only supports a subset of these ~2000 operators,
+you can use the following API on the exported program:
+
+.. code-block:: python
+
+    def run_decompositions(
+        self: ExportedProgram,
+        decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]]
+    ) -> ExportedProgram
+
+``run_decompositions`` takes in a decomposition table, which is a mapping of
+operators to a function specifying how to reduce, or decompose, that operator
+into an equivalent sequence of other ATen operators.
+
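As a rough sketch of what such a table looks like (our own illustration, not from the tutorial), the entry below decomposes ``aten.t`` into a ``permute`` call; it assumes ``ep`` is an ``ExportedProgram`` like the one created in the example below:

.. code-block:: python

    import torch

    # Decompose aten.t into an equivalent permute call (assumes a 2-D input,
    # which is what aten.t operates on).
    def decompose_t(x):
        return torch.ops.aten.permute.default(x, [1, 0])

    decomp_table = {torch.ops.aten.t.default: decompose_t}
    decomposed_ep = ep.run_decompositions(decomp_table)  # ep: an ExportedProgram (assumed)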
+The default decomposition table for ``run_decompositions`` is the
+`Core ATen decomposition table <https://github.com/pytorch/pytorch/blob/b460c3089367f3fadd40aa2cb3808ee370aa61e1/torch/_decomp/__init__.py#L252>`__,
+which will decompose all ATen operators to the
+`Core ATen Operator Set <https://pytorch.org/docs/main/torch.compiler_ir.html#core-aten-ir>`__,
+which consists of only ~180 operators.
+
+.. code-block:: python
+
+    class M(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(3, 4)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    ep = export(M(), (torch.randn(2, 3),))
+    print(ep.graph)
+
+    core_ir_ep = ep.run_decompositions()
+    print(core_ir_ep.graph)
+
+.. code-block:: bash
+
+    graph():
+        %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
+        %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
+        %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
+        %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%arg0_1,), kwargs = {})
+        %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %t), kwargs = {})
+        return (addmm,)
+    graph():
+        %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
+        %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
+        %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
+        %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
+        %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %permute), kwargs = {})
+        return (addmm,)
+
+Notice that after running ``run_decompositions`` the
+``torch.ops.aten.t.default`` operator, which is not part of the Core ATen
+Opset, has been replaced with ``torch.ops.aten.permute.default``, which is part
+of the Core ATen Opset.
+
+Most ATen operators already have decompositions, which are located
+`here <https://github.com/pytorch/pytorch/blob/b460c3089367f3fadd40aa2cb3808ee370aa61e1/torch/_decomp/decompositions.py>`__.
+If you would like to use some of these existing decomposition functions,
+you can pass a list of operators you would like to decompose to the
+`get_decompositions <https://github.com/pytorch/pytorch/blob/b460c3089367f3fadd40aa2cb3808ee370aa61e1/torch/_decomp/__init__.py#L191>`__
+function, which will return a decomposition table using the pre-implemented
+decompositions.
+
+.. code-block:: python
+
+    class M(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(3, 4)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    ep = export(M(), (torch.randn(2, 3),))
+    print(ep.graph)
+
+    from torch._decomp import get_decompositions
+    decomp_table = get_decompositions([torch.ops.aten.t.default, torch.ops.aten.transpose.int])
+    core_ir_ep = ep.run_decompositions(decomp_table)
+    print(core_ir_ep.graph)
+
+.. code-block:: bash
+
+    graph():
+        %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
+        %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
+        %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
+        %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%arg0_1,), kwargs = {})
+        %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %t), kwargs = {})
+        return (addmm,)
+    graph():
+        %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
+        %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
+        %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
+        %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
+        %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %permute), kwargs = {})
+        return (addmm,)
+
+If there is no existing decomposition function for an ATen operator that you would
+like to decompose, feel free to send a pull request to PyTorch
+implementing the decomposition!
+
ExportDB
--------