#
# ``torch.export`` traces the tensor computation graph from calling ``f(*args, **kwargs)``
# and wraps it in an ``ExportedProgram``, which can be serialized or executed later with
- # different inputs. Note that while the output ``ExportedGraph`` is callable, it is not a
- # ``torch.nn.Module``. We will detail the ``constraints`` argument later in the tutorial.
+ # different inputs. Note that while the output ``ExportedGraph`` is callable and can be
+ # called in the same way as the original input callable, it is not a ``torch.nn.Module``.
+ # We will detail the ``constraints`` argument later in the tutorial.

import torch
from torch.export import export
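# As a quick illustration of the call signature described above, exporting a
# plain function and re-running the result might look like the following
# (a minimal sketch; ``f_example`` and the input shapes are hypothetical):

def f_example(x, y):
    return x + y

# The traced graph is specialized to the shapes of these example inputs.
ep = export(f_example, (torch.randn(3), torch.randn(3)))
print(ep(torch.ones(3), torch.ones(3)))  # callable like the original function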
@@ -79,11 +80,9 @@ def forward(self, x, y):
#
# The ``graph_module`` attribute is the ``GraphModule`` that wraps the ``graph`` attribute
# so that it can be run as a ``torch.nn.Module``.
- # We can use ``graph_module``'s ``print_readable`` to print a Python code representation
- # of ``graph``:

print(exported_mod)
- exported_mod.graph_module.print_readable()
+ print(exported_mod.graph_module)

######################################################################
# The printed code shows that the FX graph only contains ATen-level ops (such as ``torch.ops.aten``)
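# To see which ops actually landed in the graph, we can also walk the FX nodes
# directly (a short sketch using the standard FX ``graph.nodes`` API):

for node in exported_mod.graph_module.graph.nodes:
    print(node.op, node.target)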
@@ -98,8 +97,6 @@ def forward(self, x, y):
# - ``range_constraints`` and ``equality_constraints`` -- constraints, covered later

print(exported_mod.graph_signature)
- print(exported_mod.range_constraints)
- print(exported_mod.equality_constraints)

######################################################################
# See the ``torch.export`` `documentation <https://pytorch.org/docs/main/export.html#torch.export.export>`__
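# ``graph_signature`` itself is a structured record; a sketch of inspecting a
# couple of its fields (the field names here are assumptions drawn from its
# printed form, not guaranteed API):

sig = exported_mod.graph_signature
print(sig.parameters)   # assumed field: names of lifted parameters
print(sig.user_inputs)  # assumed field: placeholder names of user inputs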
@@ -238,11 +235,22 @@ def false_fn(x):
# Ops can have different specializations/behaviors for different tensor shapes, so by default,
# ``torch.export`` requires inputs to ``ExportedProgram`` to have the same shape as the respective
# example inputs given to the initial ``torch.export`` call.
- # If we try to run the first ``ExportedProgram`` example with a tensor
+ # If we try to run the ``ExportedProgram`` in the example below with a tensor
# with a different shape, we get an error:

+ class MyModule2(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.lin = torch.nn.Linear(100, 10)
+
+     def forward(self, x, y):
+         return torch.nn.functional.relu(self.lin(x + y), inplace=True)
+
+ mod2 = MyModule2()
+ exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))
+
try:
-     exported_mod(torch.randn(10, 100), torch.randn(10, 100))
+     exported_mod2(torch.randn(10, 100), torch.randn(10, 100))
except Exception:
    tb.print_exc()
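# Conversely, inputs that match the example shapes run without issue (a sketch
# reusing ``exported_mod2`` from the snippet above):

out = exported_mod2(torch.randn(8, 100), torch.randn(8, 100))
print(out.shape)  # torch.Size([8, 10]), from the Linear(100, 10) layer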
@@ -355,11 +363,13 @@ def constraints_example3(x, y):
except Exception:
    tb.print_exc()

+ ######################################################################
+ # We can see that the error message suggests that we use some additional code
+ # to specify the necessary constraints. Let us use that code (the exact code may differ slightly):
+
def specify_constraints(x, y):
    return [
        # x:
-         dynamic_dim(x, 0),
-         dynamic_dim(x, 1),
        dynamic_dim(x, 0) <= 16,

        # y:
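# Constraints built this way are then passed to ``export`` through its
# ``constraints`` argument, along the lines of this sketch (the input shapes
# and the ``exported_example3`` name are hypothetical):

inp_x, inp_y = torch.randn(8, 16), torch.randn(16, 32)
exported_example3 = export(
    constraints_example3,
    (inp_x, inp_y),
    constraints=specify_constraints(inp_x, inp_y),
)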