Skip to content

Commit 21127d7

Browse files
committed
more small changes
1 parent 093a364 commit 21127d7

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

intermediate_source/torch_export_tutorial.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@
4848
#
4949
# ``torch.export`` traces the tensor computation graph from calling ``f(*args, **kwargs)``
5050
# and wraps it in an ``ExportedProgram``, which can be serialized or executed later with
51-
# different inputs. Note that while the output ``ExportedGraph`` is callable, it is not a
52-
# ``torch.nn.Module``. We will detail the ``constraints`` argument later in the tutorial.
51+
# different inputs. Note that while the output ``ExportedProgram`` is callable and can be
52+
# called in the same way as the original input callable, it is not a ``torch.nn.Module``.
53+
# We will detail the ``constraints`` argument later in the tutorial.
5354

5455
import torch
5556
from torch.export import export
@@ -79,11 +80,9 @@ def forward(self, x, y):
7980
#
8081
# The ``graph_module`` attribute is the ``GraphModule`` that wraps the ``graph`` attribute
8182
# so that it can be run as a ``torch.nn.Module``.
82-
# We can use ``graph_module``'s ``print_readable`` to print a Python code representation
83-
# of ``graph``:
8483

8584
print(exported_mod)
86-
exported_mod.graph_module.print_readable()
85+
print(exported_mod.graph_module)
8786

8887
######################################################################
8988
# The printed code shows that FX graph only contains ATen-level ops (such as ``torch.ops.aten``)
@@ -98,8 +97,6 @@ def forward(self, x, y):
9897
# - ``range_constraints`` and ``equality_constraints`` -- constraints, covered later
9998

10099
print(exported_mod.graph_signature)
101-
print(exported_mod.range_constraints)
102-
print(exported_mod.equality_constraints)
103100

104101
######################################################################
105102
# See the ``torch.export`` `documentation <https://pytorch.org/docs/main/export.html#torch.export.export>`__
@@ -238,11 +235,22 @@ def false_fn(x):
238235
# Ops can have different specializations/behaviors for different tensor shapes, so by default,
239236
# ``torch.export`` requires inputs to ``ExportedProgram`` to have the same shape as the respective
240237
# example inputs given to the initial ``torch.export`` call.
241-
# If we try to run the first ``ExportedProgram`` example with a tensor
238+
# If we try to run the ``ExportedProgram`` in the example below with a tensor
242239
# with a different shape, we get an error:
243240

241+
class MyModule2(torch.nn.Module):
242+
def __init__(self):
243+
super().__init__()
244+
self.lin = torch.nn.Linear(100, 10)
245+
246+
def forward(self, x, y):
247+
return torch.nn.functional.relu(self.lin(x + y), inplace=True)
248+
249+
mod2 = MyModule2()
250+
exported_mod2 = export(mod2, (torch.randn(8, 100), torch.randn(8, 100)))
251+
244252
try:
245-
exported_mod(torch.randn(10, 100), torch.randn(10, 100))
253+
exported_mod2(torch.randn(10, 100), torch.randn(10, 100))
246254
except Exception:
247255
tb.print_exc()
248256

@@ -355,11 +363,13 @@ def constraints_example3(x, y):
355363
except Exception:
356364
tb.print_exc()
357365

366+
######################################################################
367+
# We can see that the error message suggests to us to use some additional code
368+
# to specify the necessary constraints. Let us use that code (exact code may differ slightly):
369+
358370
def specify_constraints(x, y):
359371
return [
360372
# x:
361-
dynamic_dim(x, 0),
362-
dynamic_dim(x, 1),
363373
dynamic_dim(x, 0) <= 16,
364374

365375
# y:

0 commit comments

Comments
 (0)